diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py
new file mode 100644
index 000000000000..547ab10051f1
--- /dev/null
+++ b/tests/models/test_registry.py
@@ -0,0 +1,9 @@
+import pytest
+
+from vllm.model_executor.models import _MODELS, ModelRegistry
+
+
+@pytest.mark.parametrize("model_cls", _MODELS)
+def test_registry_imports(model_cls):
+    # Ensure all model classes can be imported successfully
+    ModelRegistry.load_model_cls(model_cls)
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index dc568928b285..d1ab20754979 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     download_weights_from_hf, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.llava import LlavaForConditionalGeneration
-
-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
+from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 
 logger = init_logger(__name__)
 
@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif model_class in _VISION_MODEL_CLASSES:
+    elif issubclass(model_class, VisionLanguageModelBase):
+        if vision_language_config is None:
+            raise ValueError("Provide `image_input_type` and other vision "
+                             "related configurations through LLM entrypoint "
+                             "or engine arguments.")
+
         extra_kwargs["vision_language_config"] = vision_language_config
     return extra_kwargs
 
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 3b99b337a276..e8a5b6237d4d 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 
+from .vlm_base import VisionLanguageModelBase
+
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
     "language_model.model": "language_model",
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
                                   text_hidden_size,
                                   bias=True)
 
-    def forward(self, image_features):
+    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         hidden_states = self.linear_1(image_features)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
@@ -50,30 +52,32 @@ class LlavaMultiModalProjector(nn.Module):
 def _merge_vision_embeddings(input_ids: torch.Tensor,
                              inputs_embeds: torch.Tensor,
                              vision_embeddings: torch.Tensor,
-                             image_token_id: int):
+                             image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)
-    inputs_embeds[mask] = vision_embeddings.view(-1,
+
+    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
+    if mask.sum() != image_feature_size:
+        raise ValueError(f"image_feature_size should be {image_feature_size}, "
+                         f"but found: {mask.sum()}")
+
+    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                  vision_embeddings.shape[-1])
 
+    return inputs_embeds
 
-class LlavaForConditionalGeneration(nn.Module):
+
+class LlavaForConditionalGeneration(VisionLanguageModelBase):
 
     def __init__(self,
-                 config: "LlavaConfig",
+                 config: LlavaConfig,
                  vision_language_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional["QuantizationConfig"] = None) -> None:
-        super().__init__()
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__(vision_language_config)
+
         self.config = config
 
-        self.vision_language_config = vision_language_config
-
-        assert self.vision_language_config, (
-            "Provide `image_input_type` and other vision "
-            "related configurations through LLM entrypoint "
-            "or engine arguments.")
-
         if self.vision_language_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        image_input: Optional[torch.Tensor] = None
-    ) -> SamplerOutput:  # noqa: E501
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata,
+                image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
         """Run forward pass for Llava 1.5.
 
         One key thing to understand is the `input_ids` already accounts for the
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
                 image_features = image_input
             vision_embeddings = self.multi_modal_projector(image_features)
             inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-            _merge_vision_embeddings(
+            inputs_embeds = _merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
                 self.vision_language_config.image_token_id)
             input_ids = None
diff --git a/vllm/model_executor/models/vlm_base.py b/vllm/model_executor/models/vlm_base.py
new file mode 100644
index 000000000000..eb0aa96e50d5
--- /dev/null
+++ b/vllm/model_executor/models/vlm_base.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+from vllm.config import VisionLanguageConfig
+
+
+class VisionLanguageModelBase(nn.Module):
+    """Base class for all vision language models (VLMs)."""
+
+    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
+        super().__init__()
+
+        self.vision_language_config = vision_language_config
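
For reviewers, a minimal standalone sketch of the embedding-merge contract this patch tightens. This is not vLLM code: `merge_vision_embeddings` is a local copy of the patched `_merge_vision_embeddings` logic, and the tensor shapes and `IMAGE_TOKEN_ID` value are illustrative assumptions; only `torch` is required.

```python
import torch

IMAGE_TOKEN_ID = 32000  # placeholder token id, chosen for illustration only


def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    """Local mirror of the patched _merge_vision_embeddings logic."""
    mask = (input_ids == image_token_id)

    # The patch validates that the number of image-token positions matches
    # the number of vision embedding vectors before scattering them in place.
    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
    if mask.sum() != image_feature_size:
        raise ValueError(f"image_feature_size should be {image_feature_size}, "
                         f"but found: {mask.sum()}")

    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
                                                 vision_embeddings.shape[-1])
    return inputs_embeds


# Toy example: 1 image with 4 patch embeddings, hidden size 8.
hidden_size = 8
vision_embeddings = torch.randn(1, 4, hidden_size)
input_ids = torch.tensor(
    [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
inputs_embeds = torch.zeros(input_ids.shape[0], hidden_size)

merged = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings,
                                 IMAGE_TOKEN_ID)
assert torch.equal(merged[1:5], vision_embeddings.view(4, hidden_size))

# A prompt with the wrong number of image tokens now fails fast instead of
# silently producing misaligned embeddings.
bad_ids = torch.tensor([1, IMAGE_TOKEN_ID, 2])
try:
    merge_vision_embeddings(bad_ids, torch.zeros(3, hidden_size),
                            vision_embeddings, IMAGE_TOKEN_ID)
except ValueError as e:
    print(e)
```

Returning `inputs_embeds` and validating the token count means a prompt whose number of image placeholder tokens does not match the vision feature count raises a clear error; combined with the `issubclass(model_class, VisionLanguageModelBase)` check in the loader, a new VLM only needs to subclass the base class to have `vision_language_config` injected, instead of being added to a hard-coded `_VISION_MODEL_CLASSES` list.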