[6/N] pass whole config to inner model (#10205)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent: f0f2e5638e
commit: f89d18ff74
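The diff below repeats one mechanical change across many model files: inner models stop receiving `config`, `cache_config`, `quant_config` (and sometimes `lora_config`) as separate constructor arguments and instead take the whole `VllmConfig` plus a dotted `prefix`, unpacking the sub-configs they need and passing `maybe_prefix(prefix, ...)` down to their children. The sketch that follows is illustrative only and is not code from this commit: `ToyModel`, `ToyForCausalLM`, and the `SimpleNamespace`/dataclass stand-ins for the config objects are assumptions made for the example, while the keyword-only `vllm_config`/`prefix` signature and the `maybe_prefix` helper mirror what the hunks below introduce.

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass
class VllmConfig:
    """Stand-in for vllm.config.VllmConfig, reduced to the fields used below."""
    model_config: Any = None
    cache_config: Optional[Any] = None
    quant_config: Optional[Any] = None
    lora_config: Optional[Any] = None


def maybe_prefix(prefix: str, name: str) -> str:
    # Mirrors .utils.maybe_prefix: join the child name onto a dotted prefix.
    return f"{prefix}.{name}" if prefix else name


class ToyModel:
    # New convention: keyword-only vllm_config plus prefix; the module pulls
    # out only the sub-configs it needs.
    # Old convention (removed by this series):
    #     def __init__(self, config, cache_config=None, quant_config=None, prefix=""):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.prefix = prefix


class ToyForCausalLM:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # The outer model forwards the whole config and a dotted prefix
        # ("model", "transformer", "language_model", ...) to the inner model.
        self.model = ToyModel(vllm_config=vllm_config,
                              prefix=maybe_prefix(prefix, "model"))


if __name__ == "__main__":
    hf_config = SimpleNamespace(vocab_size=32000)
    cfg = VllmConfig(model_config=SimpleNamespace(hf_config=hf_config))
    lm = ToyForCausalLM(vllm_config=cfg)
    print(lm.model.prefix, lm.model.vocab_size)  # model 32000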
@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

logger = init_logger(__name__)

@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
@support_torch_compile
class ArcticModel(nn.Module):

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(

@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(config,
                                 cache_config,
                                 quant_config,
                                 prefix=prefix)
        self.model = ArcticModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
@support_torch_compile
class BaiChuanModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = BaiChuanModel(config, position_embedding, cache_config,
                                   quant_config)
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)

@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        if config.hidden_size == 4096:  # baichuan2 7b
            super().__init__(vllm_config, prefix, "ROPE")
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ROPE")
        else:  # baichuan 13b, baichuan2 13b
            super().__init__(vllm_config, prefix, "ALIBI")
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ALIBI")


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):

@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__(vllm_config, prefix, "ROPE")
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding="ROPE")
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix

logger = logging.get_logger(__name__)

@@ -739,13 +741,14 @@ class BartModel(nn.Module):
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config

        self.padding_idx = config.pad_token_id

@@ -810,20 +813,16 @@ class BartModel(nn.Module):
class BartForConditionalGeneration(nn.Module):
    base_model_prefix = "model"

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.config = config
        self.model = BartModel(config,
                               cache_config,
                               quant_config,
                               lora_config=lora_config)
        self.model = BartModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .utils import maybe_prefix


class BertEmbedding(nn.Module):

@@ -309,12 +311,13 @@ class BertOutput(nn.Module):

class BertModel(nn.Module):

    def __init__(self,
                 config: BertConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(config,
                                   cache_config,

@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
    _pooler: An instance of Pooler used for pooling operations.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        pooler_config = vllm_config.model_config.pooler_config
        self.model = BertModel(config, cache_config, quant_config)
        self.model = BertModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))
        self._pooler = Pooler.from_config_with_defaults(
            pooler_config,
            pooling_type=PoolingType.CLS,
@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                    merge_multimodal_embeddings)
                    maybe_prefix, merge_multimodal_embeddings)

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo

@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config

@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
        self.language_model = init_vllm_registered_model(
            config.text_config,
            vllm_config=vllm_config,
            prefix="language_model")
            prefix=maybe_prefix(prefix, "language_model"))

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)
@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:

@@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
@support_torch_compile
class BloomModel(nn.Module):

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding

@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(config, cache_config, quant_config)
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
# These configs are not part of the model config but the preprocessor
|
||||
# and processor files, so we hardcode them in the model file for now.
|
||||
@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
|
||||
|
||||
class ChameleonModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChameleonConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
|
||||
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.model = ChameleonModel(config, cache_config, quant_config)
|
||||
self.model = ChameleonModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
|
||||
@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||
@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
self.quant_config = quant_config
|
||||
self.max_position_embeddings = getattr(config, "max_sequence_length",
|
||||
8192)
|
||||
self.transformer = ChatGLMModel(config, cache_config, quant_config)
|
||||
self.transformer = ChatGLMModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.transformer.output_layer.weight = (
|
||||
self.transformer.embedding.weight)
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers import CohereConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
@torch.compile
|
||||
@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class CohereModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {"embed_tokens": "input_embeddings"}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
scale=config.logit_scale)
|
||||
self.model = CohereModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
self.model = CohereModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.sampler = get_sampler()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class DbrxRouter(nn.Module):
|
||||
@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
|
||||
|
||||
class DbrxModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
if config.tie_word_embeddings:
|
||||
@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
"tie_word_embeddings is not supported for Dbrx models.")
|
||||
self.quant_config = quant_config
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.transformer = DbrxModel(config, cache_config, quant_config)
|
||||
self.transformer = DbrxModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
|
||||
@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
|
||||
delattr(config, "num_key_value_heads_per_layer")
|
||||
|
||||
@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class DeepseekMLP(nn.Module):
|
||||
@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
|
||||
|
||||
class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, cache_config, quant_config)
|
||||
self.model = DeepseekModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class DeepseekV2MLP(nn.Module):
|
||||
@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekV2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="model")
|
||||
self.model = DeepseekV2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class EAGLE(nn.Module):
|
||||
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
||||
@ -42,7 +44,8 @@ class EAGLE(nn.Module):
|
||||
architectures = getattr(self.config.model, "architectures", [])
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
|
||||
|
||||
self.model = model_cls(vllm_config, prefix)
|
||||
self.model = model_cls(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||
config.model.hidden_size,
|
||||
bias=getattr(self.config, "eagle_fc_bias", False))
|
||||
|
||||
@ -29,7 +29,7 @@ from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class ExaoneGatedMLP(nn.Module):
|
||||
@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class ExaoneModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ExaoneConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"c_fc_1": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.transformer = ExaoneModel(
|
||||
config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model",
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
|
||||
@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
|
||||
@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class FalconModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = FalconModel(config, cache_config, quant_config)
|
||||
self.transformer = FalconModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
# only Falcon-11B doesn't share lm_head weight with word embeddings
|
||||
# and previous Falcon model doesn't have tie_word_embeddings config
|
||||
# so we set tie_word_embeddings to True by default
|
||||
|
||||
@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
|
||||
@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
|
||||
|
||||
class Florence2LanguageModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
|
||||
|
||||
class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
self.model = Florence2LanguageModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.model = Florence2LanguageModel(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
embed_scale = math.sqrt(
|
||||
config.d_model) if config.scale_embedding else 1.0
|
||||
|
||||
@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
|
||||
class Florence2ForConditionalGeneration(nn.Module):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
# TODO(Isotr0py): Add vision backbone
|
||||
self.language_model = Florence2LanguageForConditionalGeneration(
|
||||
config=config.text_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
|
||||
@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class GemmaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = GemmaModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
self.model = GemmaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class Gemma2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
del lora_config # Unused.
|
||||
@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
# currently all existing Gemma models have `tie_word_embeddings` enabled
|
||||
assert config.tie_word_embeddings
|
||||
self.quant_config = quant_config
|
||||
self.model = Gemma2Model(config, cache_config, quant_config)
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.vocab_size, soft_cap=config.final_logit_softcapping)
|
||||
self.sampler = get_sampler()
|
||||
@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
_pooler: An instance of Pooler used for pooling operations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
self.model = Gemma2Model(vllm_config, prefix)
|
||||
self.model = Gemma2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
vllm_config.model_config.pooler_config,
|
||||
pooling_type=PoolingType.LAST,
|
||||
|
||||
@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
|
||||
@support_torch_compile
|
||||
class GPT2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
assert not config.scale_attn_by_inverse_layer_idx
|
||||
@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="transformer")
|
||||
self.transformer = GPT2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.transformer.wte
|
||||
else:
|
||||
|
||||
@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -189,15 +189,14 @@ class GPTBigCodeBlock(nn.Module):
|
||||
@support_torch_compile
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
|
||||
@ -265,7 +264,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
@ -273,8 +271,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
|
||||
lora_config)
|
||||
self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
|
||||
prefix=prefix)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.transformer.wte
|
||||
else:
|
||||
|
||||
@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GPTJAttention(nn.Module):
|
||||
@ -177,14 +178,13 @@ class GPTJBlock(nn.Module):
|
||||
@support_torch_compile
|
||||
class GPTJModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.embed_dim = config.n_embd
|
||||
self.wte = VocabParallelEmbedding(
|
||||
@ -236,12 +236,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
assert not config.tie_word_embeddings
|
||||
self.transformer = GPTJModel(config, cache_config, quant_config)
|
||||
self.transformer = GPTJModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.n_embd,
|
||||
|
||||
@ -41,7 +41,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
@ -189,14 +190,13 @@ class GPTNeoXLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class GPTNeoXModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embed_in = VocabParallelEmbedding(
|
||||
@ -249,11 +249,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
|
||||
self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "gpt_neox"))
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers import GraniteConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -52,7 +52,8 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class GraniteMLP(nn.Module):
|
||||
@ -257,15 +258,14 @@ class GraniteDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class GraniteModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = GraniteModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.model = GraniteModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from . import mixtral
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import make_layers
|
||||
from .utils import make_layers, maybe_prefix
|
||||
|
||||
|
||||
class GraniteMoeMoE(nn.Module):
|
||||
@ -247,15 +247,14 @@ class GraniteMoeDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class GraniteMoeModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: GraniteMoeConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = GraniteMoeModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.model = GraniteMoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@ -22,17 +22,15 @@ import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from torch import nn
|
||||
# Temporary solution for transformers below 4.46.0.
|
||||
from transformers import PretrainedConfig as Idefics3Config
|
||||
from transformers import ProcessorMixin as Idefics3ImageProcessor
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@ -48,7 +46,8 @@ from .idefics2_vision_model import (
|
||||
# yapf: enable
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .llama import LlamaModel
|
||||
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -417,13 +416,13 @@ class Idefics3Connector(nn.Module):
|
||||
|
||||
class Idefics3Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Idefics3Config,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = self.config.text_config.pad_token_id
|
||||
self.vocab_size = self.config.text_config.vocab_size
|
||||
@ -613,22 +612,18 @@ class Idefics3Model(nn.Module):
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
|
||||
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.model = Idefics3Model(config, cache_config, quant_config)
|
||||
self.model = Idefics3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.image_token_id = self.config.image_token_id
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
|
||||
@ -250,14 +250,13 @@ class InternLMDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class InternLM2Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -317,20 +316,13 @@ class InternLM2Model(nn.Module):
|
||||
|
||||
class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = InternLM2Model(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
self.model = InternLM2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.output = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
@ -104,14 +104,13 @@ class InternLM2VEDecoderLayer(nn.Module):
|
||||
|
||||
class InternLM2VEModel(InternLM2Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, cache_config, quant_config)
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: InternLM2VEDecoderLayer(
|
||||
@ -159,12 +158,8 @@ class InternLM2VEModel(InternLM2Model):
|
||||
|
||||
class InternLM2VEForCausalLM(InternLM2ForCausalLM):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(vllm_config, prefix=prefix)
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
|
||||
@ -35,7 +35,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_num_patches)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -435,13 +435,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
is_mono=self.is_mono,
|
||||
prefix="vision_model",
|
||||
prefix=maybe_prefix(prefix, "vision_model"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
|
||||
@ -44,7 +44,8 @@ from vllm.transformers_utils.configs import JAISConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class SwiGLUActivation(nn.Module):
|
||||
@ -215,14 +216,13 @@ class JAISBlock(nn.Module):
|
||||
@support_torch_compile
|
||||
class JAISModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
assert not config.scale_attn_by_inverse_layer_idx
|
||||
@ -293,11 +293,12 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.transformer = JAISModel(config, cache_config, quant_config)
|
||||
self.transformer = JAISModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "transformer"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.transformer.wte
|
||||
else:
|
||||
|
||||
@ -7,7 +7,7 @@ from transformers import JambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -29,6 +29,7 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||
_get_graph_batch_size)
|
||||
|
||||
from .interfaces import HasInnerState, SupportsLoRA
|
||||
from .utils import maybe_prefix
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
@ -258,14 +259,14 @@ ALL_DECODER_LAYER_TYPES = {
|
||||
|
||||
class JambaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: JambaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
@ -364,10 +360,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model = JambaModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
lora_config=lora_config)
|
||||
self.model = JambaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -271,15 +271,14 @@ class LlamaDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"norm": "model.norm"
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = LlamaModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.model = LlamaModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config,
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
|
||||
@ -32,7 +32,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
|
||||
input_processor_for_siglip)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@ -282,7 +282,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
config,
|
||||
quant_config,
|
||||
require_post_norm=False,
|
||||
prefix="vision_tower")
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
@ -291,7 +291,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -31,7 +31,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
|
||||
init_vllm_registered_model)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
class LlavaNextImagePixelInputs(TypedDict):
|
||||
@ -296,7 +296,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config,
|
||||
quant_config,
|
||||
require_post_norm=False,
|
||||
prefix="vision_tower")
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
self.multi_modal_projector = LlavaMultiModalProjector(
|
||||
@ -307,7 +307,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
|
||||
# The same model class supports both language generation and embedding
|
||||
# because the architecture name is the same
|
||||
|
||||
@ -29,7 +29,7 @@ from .llava import init_vision_tower_for_llava
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# For profile run
|
||||
_MAX_FRAMES_PER_VIDEO = 32
|
||||
@ -267,7 +267,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config,
|
||||
quant_config,
|
||||
require_post_norm=False,
|
||||
prefix="vision_tower")
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
self.vision_resampler = LlavaNextVideoPooler(config)
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
@ -276,7 +276,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -35,7 +35,7 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
|
||||
dummy_video_for_siglip, get_siglip_image_feature_size,
|
||||
get_siglip_patch_grid_length, input_processor_for_siglip)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
|
||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
|
||||
@ -418,12 +418,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
config,
|
||||
quant_config,
|
||||
require_post_norm=False,
|
||||
prefix="vision_tower")
|
||||
prefix=maybe_prefix(prefix, "vision_tower"))
|
||||
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
self.image_newline = nn.Parameter(
|
||||
torch.empty(config.text_config.hidden_size))
|
||||
|
||||
|
||||
@ -6,7 +6,7 @@ from torch import nn
|
||||
from transformers import MambaConfig
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -26,6 +26,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
|
||||
_get_graph_batch_size)
|
||||
|
||||
from .utils import maybe_prefix
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
@ -73,14 +75,14 @@ class MambaDecoderLayer(nn.Module):
|
||||
|
||||
class MambaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MambaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||
@ -130,14 +132,9 @@ class MambaModel(nn.Module):
|
||||
|
||||
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
assert not cache_config.enable_prefix_caching, \
|
||||
@ -146,10 +143,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.backbone = MambaModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
lora_config=lora_config)
|
||||
self.backbone = MambaModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "backbone"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@ -29,7 +29,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class MiniCPMMoE(nn.Module):
|
||||
@ -351,15 +352,14 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class MiniCPMModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.prefix = prefix
|
||||
self.vllm_config = vllm_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.cache_config = cache_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self._init_model()
|
||||
self._init_model(vllm_config=vllm_config, prefix=prefix)
|
||||
unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
@ -502,11 +500,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _init_model(self):
|
||||
self.model = MiniCPMModel(config=self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
lora_config=self.lora_config)
|
||||
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
self.model = MiniCPMModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@ -28,7 +28,7 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
|
||||
MiniCPMForCausalLM,
|
||||
MiniCPMModel)
|
||||
|
||||
from .utils import make_layers
|
||||
from .utils import make_layers, maybe_prefix
|
||||
|
||||
|
||||
class MiniCPM3Attention(nn.Module):
|
||||
@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
|
||||
# `embedding_modules` and `embedding_padding_modules`
|
||||
# are inherited from MiniCPMForCausalLM
|
||||
|
||||
def _init_model(self):
|
||||
self.model = MiniCPM3Model(config=self.config,
|
||||
cache_config=self.cache_config,
|
||||
quant_config=self.quant_config,
|
||||
lora_config=self.lora_config)
|
||||
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
self.model = MiniCPM3Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
@ -34,7 +34,7 @@ from transformers import PretrainedConfig
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||
from .utils import is_pp_missing_parameter
|
||||
from .utils import is_pp_missing_parameter, maybe_prefix
|
||||
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
"llm.lm_head": "lm_head",
|
||||
@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
):
|
||||
config = vllm_config.model_config.hf_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
super().__init__()
|
||||
# All MiniCPM-V models disable `tie_word_embeddings` but
|
||||
@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.version = get_version_by_config(self.config)
|
||||
self.llm = self.init_llm(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix="llm")
|
||||
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm")
|
||||
self.llm = self.init_llm(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "llm"))
|
||||
self.vpm = self.init_vision_module(config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vpm"))
|
||||
param_dtype = torch.get_default_dtype()
|
||||
self.vpm.to(dtype=param_dtype)
|
||||
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
|
||||
@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.resampler = self.init_resampler(self.embed_dim,
|
||||
self.vision_dim,
|
||||
quant_config=quant_config,
|
||||
prefix="resampler")
|
||||
prefix=maybe_prefix(
|
||||
prefix, "resampler"))
|
||||
self.resampler.to(device="cuda", dtype=param_dtype)
|
||||
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix="llm.lm_head")
|
||||
prefix=maybe_prefix(
|
||||
prefix, "llm.lm_head"))
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
|
||||
return LLMWrapper(MiniCPMModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(
|
||||
@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
return LLMWrapper(LlamaModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(
|
||||
@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
|
||||
|
||||
def init_llm(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> nn.Module:
|
||||
|
||||
return LLMWrapper(Qwen2Model(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix),
|
||||
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
|
||||
name="model")
|
||||
|
||||
def init_vision_module(
|
||||
@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
|
||||
if instance_class is None:
|
||||
raise ValueError(
|
||||
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
|
||||
return instance_class(vllm_config, prefix=prefix)
|
||||
return instance_class(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
@ -28,7 +28,7 @@ from transformers import MixtralConfig
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
@ -248,15 +249,14 @@ class MixtralDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class MixtralModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
(lora_config.max_loras or 1)) if lora_config else 0
|
||||
@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = MixtralModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.model = MixtralModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
|
||||
@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
@ -293,14 +294,13 @@ class MixtralDecoderLayer(nn.Module):
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@ -350,18 +350,14 @@ class MixtralModel(nn.Module):
|
||||
class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MixtralModel(config, cache_config, quant_config)
|
||||
self.model = MixtralModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
|
||||
import vllm.distributed.parallel_state as ps
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
|
||||
InputContext, TokenInputs, token_inputs)
|
||||
@ -56,6 +56,7 @@ from vllm.utils import is_list_of
|
||||
from .clip import CLIPMLP
|
||||
from .interfaces import SupportsMultiModal
|
||||
from .llama import LlamaDecoderLayer, LlamaMLP
|
||||
from .utils import maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
MLLAMA_IMAGE_TOKEN_ID = 128256
|
||||
@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module):
|
||||
config_class = config_mllama.MllamaTextConfig
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config.text_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
|
||||
@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module):
|
||||
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: config_mllama.MllamaTextConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config.text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = MllamaTextModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
self.model = MllamaTextModel(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.model")
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
@ -1127,12 +1117,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
self.vision_model = MllamaVisionModel(config.vision_config,
|
||||
quant_config,
|
||||
prefix="vision_model")
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_model"))
|
||||
self.language_model = MllamaForCausalLM(
|
||||
config.text_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix="language_model",
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
self.multi_modal_projector = ColumnParallelLinear(
|
||||
config.vision_config.vision_output_dim,
|
||||
@ -1140,7 +1129,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
gather_output=True,
|
||||
prefix="multi_modal_projector",
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.output_hidden_states,
|
||||
config.text_config.vocab_size)
|
||||
|
||||
@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (get_vit_attn_backend,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
# TODO: hard-coded for now. Consider making it configurable.
|
||||
VIT_LAYERS = [-2, -9]
|
||||
@ -716,14 +717,13 @@ class MolmoVisionBackbone(nn.Module):
|
||||
@support_torch_compile
|
||||
class MolmoModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embedding_size = config.embedding_size or config.vocab_size
|
||||
@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
|
||||
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
@ -1040,7 +1035,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
vision_config = VisionBackboneConfig()
|
||||
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
|
||||
quant_config)
|
||||
self.model = MolmoModel(config, cache_config, quant_config)
|
||||
self.model = MolmoModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if self.config.weight_tying:
|
||||
self.lm_head = self.model.transformer.wte
|
||||
|
||||
@ -26,7 +26,8 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
def _get_alibi_slopes(
|
||||
@ -207,14 +208,13 @@ class MPTBlock(nn.Module):
|
||||
@support_torch_compile
|
||||
class MPTModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
assert config.embedding_fraction == 1.0
|
||||
assert config.norm_type == "low_precision_layernorm"
|
||||
|
||||
@ -267,20 +267,16 @@ class MPTModel(nn.Module):
|
||||
|
||||
class MPTForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
assert config.tie_word_embeddings
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.transformer = MPTModel(config, cache_config, quant_config)
|
||||
self.transformer = MPTModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "transformer"))
|
||||
self.lm_head = self.transformer.wte
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = get_sampler()
|
||||
|
||||
@ -27,7 +27,7 @@ from torch import nn
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
# The architecture is pretty similar to Llama, with these changes:
|
||||
# - There is no gate_proj, just up_proj
|
||||
@ -293,15 +294,14 @@ class NemotronDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class NemotronModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (lora_config.lora_extra_vocab_size *
|
||||
@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"v_proj": ("qkv_proj", 2),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
assert isinstance(config, NemotronConfig)
|
||||
@ -416,11 +411,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = NemotronModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
lora_config=lora_config,
|
||||
prefix="model")
|
||||
self.model = NemotronModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
||||
@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class OlmoAttention(nn.Module):
|
||||
@ -224,12 +225,13 @@ class OlmoDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class OlmoModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
Extremely barebones HF model wrapper.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.model = OlmoModel(config, cache_config, quant_config)
|
||||
self.model = OlmoModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
|
||||
@ -38,7 +38,8 @@ from vllm.utils import print_warning_once
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class OlmoeMoE(nn.Module):
|
||||
@ -243,14 +244,13 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class OlmoeModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OlmoeModel(config, cache_config, quant_config)
|
||||
self.model = OlmoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -293,14 +293,13 @@ class OPTDecoder(nn.Module):
|
||||
@support_torch_compile
|
||||
class OPTModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.decoder = OPTDecoder(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OPTModel(config,
|
||||
cache_config,
|
||||
quant_config,
|
||||
self.model = OPTModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.decoder.embed_tokens
|
||||
|
||||
@ -29,7 +29,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class OrionMLP(nn.Module):
|
||||
@ -208,14 +209,13 @@ class OrionDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class OrionModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -268,18 +268,14 @@ class OrionModel(nn.Module):
|
||||
|
||||
class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = OrionModel(config, cache_config, quant_config)
|
||||
self.model = OrionModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
@ -20,7 +20,7 @@ from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
|
||||
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
|
||||
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -131,11 +131,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
|
||||
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@ -145,7 +141,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config,
|
||||
quant_config,
|
||||
prefix="vision_tower")
|
||||
prefix=maybe_prefix(
|
||||
prefix, "vision_tower"))
|
||||
self.multi_modal_projector = PaliGemmaMultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
projection_dim=config.vision_config.projection_dim)
|
||||
@ -155,7 +152,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config,
|
||||
vllm_config=vllm_config,
|
||||
prefix="language_model")
|
||||
prefix=maybe_prefix(prefix, "language_model"))
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
|
||||
@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class PersimmonMLP(nn.Module):
|
||||
@ -212,12 +213,13 @@ class PersimmonDecoderLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class PersimmonModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PersimmonConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
@ -265,20 +267,13 @@ class PersimmonModel(nn.Module):
|
||||
|
||||
class PersimmonForCausalLM(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model = PersimmonModel(config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.model = PersimmonModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
|
||||
@ -60,7 +60,8 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
|
||||
class PhiAttention(nn.Module):
|
||||
@ -196,12 +197,13 @@ class PhiLayer(nn.Module):
|
||||
@support_torch_compile
|
||||
class PhiModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PhiConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
@ -294,7 +291,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = PhiModel(config, cache_config, quant_config)
|
||||
self.model = PhiModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


def load_column_parallel_weight(param: torch.nn.Parameter,
@ -299,14 +300,13 @@ class Phi3SmallDecoderLayer(nn.Module):

class Phi3SmallModel(nn.Module):

def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
@ -363,18 +363,14 @@ class Phi3SmallModel(nn.Module):
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"]

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config)
self.model = Phi3SmallModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(

@ -45,7 +45,7 @@ from vllm.utils import is_list_of

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)

logger = init_logger(__name__)
@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
@ -544,12 +540,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
prefix="model.embed_tokens",
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)

# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, quant_config, prefix="model.vision_embed_tokens")
config,
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"

@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class PhiMoEConfig(PretrainedConfig):
@ -432,15 +433,14 @@ class PhiMoEDecoderLayer(nn.Module):
@support_torch_compile
class PhiMoEModel(nn.Module):

def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config

self.model = PhiMoEModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

@ -38,7 +38,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model
from .utils import init_vllm_registered_model, maybe_prefix

try:
from xformers import ops as xops
@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@ -176,7 +172,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))

self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(

@ -50,7 +50,8 @@ from vllm.utils import is_list_of

from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

logger = init_logger(__name__)

@ -552,14 +553,13 @@ class QWenBlock(nn.Module):
@support_torch_compile
class QWenModel(nn.Module):

def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.config = config
self.vocab_size = config.vocab_size

@ -865,20 +865,17 @@ def dummy_data_for_qwen(

class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config)
self.transformer = QWenModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)

@ -240,14 +240,13 @@ class Qwen2DecoderLayer(nn.Module):
@support_torch_compile
class Qwen2Model(nn.Module):

def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -429,9 +424,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config

self.quant_config = quant_config
self.model = Qwen2Model(config,
cache_config,
quant_config,
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if config.tie_word_embeddings:

@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
@ -283,8 +278,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,

self.quant_config = quant_config

self.language_model = Qwen2Model(config.text_config, cache_config,
quant_config)
self.language_model = Qwen2Model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix)
self.unpadded_vocab_size = config.text_config.vocab_size
if config.text_config.tie_word_embeddings:
self.lm_head = self.language_model.embed_tokens

@ -17,7 +17,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .utils import AutoWeightsLoader
from .utils import AutoWeightsLoader, maybe_prefix


class Qwen2ForSequenceClassification(nn.Module):
@ -43,11 +43,7 @@ class Qwen2ForSequenceClassification(nn.Module):
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -70,7 +66,8 @@ class Qwen2ForSequenceClassification(nn.Module):
self.lora_config = lora_config

self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self.score = RowParallelLinear(config.hidden_size,
config.num_labels,

@ -54,7 +54,8 @@ from vllm.utils import print_warning_once

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class Qwen2MoeMLP(nn.Module):
@ -315,14 +316,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
@support_torch_compile
class Qwen2MoeModel(nn.Module):

def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

@ -377,18 +377,14 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):

fall_back_to_pt_during_load = False

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.model = Qwen2MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)

@ -18,7 +18,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsPP
from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader
from .utils import AutoWeightsLoader, maybe_prefix


class ReLU(nn.Module):
@ -55,11 +55,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -82,7 +78,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
self.lora_config = lora_config

self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self.score = nn.Sequential(
ColumnParallelLinear(config.hidden_size,

@ -70,7 +70,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
make_empty_intermediate_tensors_factory, maybe_prefix)

logger = init_logger(__name__)

@ -966,11 +966,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {}
embedding_padding_modules = []

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -986,13 +982,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix="visual",
prefix=maybe_prefix(prefix, "visual"),
)

self.model = Qwen2Model(config,
cache_config,
quant_config,
prefix="model")
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
@ -1001,7 +995,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix="lm_head")
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()

@ -29,7 +29,7 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class SolarMLP(nn.Module):
@ -266,15 +267,14 @@ class SolarDecoderLayer(nn.Module):
@support_torch_compile
class SolarModel(nn.Module):

def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
@ -409,25 +409,17 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config

self.model = SolarModel(
config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model",
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"),
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size

@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class StablelmMLP(nn.Module):
@ -193,12 +194,13 @@ class StablelmDecoderLayer(nn.Module):

class StableLMEpochModel(nn.Module):

def __init__(self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '') -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@ -245,18 +247,14 @@ class StableLMEpochModel(nn.Module):

class StablelmForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config)
self.model = StableLMEpochModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)

@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class Starcoder2Attention(nn.Module):
@ -195,12 +196,13 @@ class Starcoder2DecoderLayer(nn.Module):
@support_torch_compile
class Starcoder2Model(nn.Module):

def __init__(self,
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -245,19 +247,13 @@ class Starcoder2Model(nn.Module):

class Starcoder2ForCausalLM(nn.Module, SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = Starcoder2Model(config,
cache_config,
quant_config=quant_config)
self.model = Starcoder2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:

@ -34,7 +34,7 @@ from vllm.utils import is_list_of

from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map)

_AUDIO_PLACEHOLDER_TOKEN = 128002
@ -339,11 +339,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
@ -354,6 +350,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append(
DefaultModelLoader.Source(
model_or_path=config.audio_model_id,
@ -362,8 +360,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
))
self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, vllm_config, prefix="language_model")
config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None,

@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class XverseMLP(nn.Module):
@ -223,11 +224,7 @@ class XverseDecoderLayer(nn.Module):
@support_torch_compile
class XverseModel(nn.Module):

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -315,15 +312,10 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]

def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config

@ -331,7 +323,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config

self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.model = XverseModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
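
The hunks above all apply one mechanical change: the inner *Model now accepts the whole vllm_config keyword-only and unpacks hf_config, cache_config, quant_config (and, where needed, lora_config) itself, while the outer wrapper simply forwards vllm_config and composes the submodule prefix with maybe_prefix. The sketch below is not vLLM source: the Toy* classes and the stand-in maybe_prefix/config dataclasses are assumptions for illustration only; only the attribute names mirror the diff.

# Minimal, self-contained sketch of the constructor pattern (illustrative only).
from dataclasses import dataclass
from typing import Any, Optional


def maybe_prefix(prefix: str, name: str) -> str:
    # Stand-in for the maybe_prefix helper imported from .utils above:
    # prepend the parent prefix only when one is set.
    return name if not prefix else f"{prefix}.{name}"


@dataclass
class ToyModelConfig:
    hf_config: Any = None  # the HF config object in the real code


@dataclass
class ToyVllmConfig:
    model_config: ToyModelConfig
    cache_config: Optional[Any] = None
    quant_config: Optional[Any] = None
    lora_config: Optional[Any] = None


class ToyInnerModel:
    # Inner model: receives the whole config object and unpacks what it needs.
    def __init__(self, *, vllm_config: ToyVllmConfig, prefix: str = ""):
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.prefix = prefix


class ToyForCausalLM:
    # Outer wrapper: forwards vllm_config unchanged and composes the child's
    # prefix with maybe_prefix instead of re-plumbing individual configs.
    def __init__(self, *, vllm_config: ToyVllmConfig, prefix: str = ""):
        self.model = ToyInnerModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))


if __name__ == "__main__":
    cfg = ToyVllmConfig(model_config=ToyModelConfig(hf_config=object()))
    print(ToyForCausalLM(vllm_config=cfg).model.prefix)  # -> "model"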