Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-09 00:18:29 +08:00)
[Model][Bugfix] Implicit model flags and reenable Phi-3-Vision (#5896)
commit 98cf2ed678
parent e9d32d077d

The diff below gives the SupportsVision and SupportsLoRA interfaces default flag values, drops the now-redundant per-model supports_vision / supports_lora assignments, and ports Phi3VForCausalLM from VisionLanguageModelBase to the nn.Module + SupportsVision composition.
@@ -295,8 +295,6 @@ class BaiChuanModel(nn.Module):


 class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "W_pack": ["W_pack"],
         "gate_up_proj": [
@@ -325,8 +325,6 @@ class ChatGLMModel(nn.Module):


 class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -291,8 +291,6 @@ class GemmaModel(nn.Module):


 class GemmaForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -233,8 +233,6 @@ class GPTBigCodeModel(nn.Module):


 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {"c_attn": ["c_attn"]}

     supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -13,7 +13,14 @@ logger = init_logger(__name__)
 class SupportsVision(Protocol):
     """The interface required for all vision language models (VLMs)."""

-    supports_vision: ClassVar[Literal[True]]
+    supports_vision: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports vision inputs.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """

     def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
         ...
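The new default is what makes the flag "implicit": any class with SupportsVision in its MRO now advertises vision support automatically, which is why the per-model assignments below can be deleted. A minimal self-contained sketch of the pattern (MyVLM is a hypothetical name, not part of this diff):

    from typing import ClassVar, Literal, Protocol

    class SupportsVision(Protocol):
        # The default value travels down the MRO, so subclasses
        # no longer need to write `supports_vision = True` themselves.
        supports_vision: ClassVar[Literal[True]] = True

    class MyVLM(SupportsVision):  # hypothetical model class
        pass

    assert MyVLM.supports_vision is True  # inherited, never redefined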
@@ -52,7 +59,14 @@ def supports_vision(
 class SupportsLoRA(Protocol):
     """The interface required for all models that support LoRA."""

-    supports_lora: ClassVar[Literal[True]]
+    supports_lora: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports LoRA.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """

     packed_modules_mapping: ClassVar[Dict[str, List[str]]]
     supported_lora_modules: ClassVar[List[str]]
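As the `def supports_vision(` context above suggests, the interface module pairs each flag with a checker. A hedged sketch of how such a check can work once the flag carries a default; the helper below is our illustration, not the code from this diff:

    def _has_lora_support(model_cls: type) -> bool:
        # With a default on SupportsLoRA, the attribute is present on every
        # class that has the protocol anywhere in its MRO; classes that
        # never opted in simply lack it, so attribute lookup suffices.
        return bool(getattr(model_cls, "supports_lora", False))

    assert not _has_lora_support(object)  # no flag anywhere in object's MRO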
@@ -299,8 +299,6 @@ class LlamaModel(nn.Module):


 class LlamaForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -88,8 +88,6 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):

-    supports_vision = True
-
     def __init__(self,
                  config: LlavaConfig,
                  vlm_config: VisionLanguageConfig,
@@ -108,8 +108,6 @@ def _image_pixel_processor(
 @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

-    supports_vision = True
-
     def __init__(self,
                  config: LlavaNextConfig,
                  vlm_config: VisionLanguageConfig,
@@ -392,8 +392,6 @@ class MiniCPMModel(nn.Module):


 class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -475,8 +475,6 @@ class MixtralModel(nn.Module):


 class MixtralForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -232,8 +232,6 @@ class PhiModel(nn.Module):


 class PhiForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -32,12 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
-from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsVision
+
 logger = init_logger(__name__)

 _KEYS_TO_MODIFY_MAPPING = {
@@ -317,18 +318,21 @@ def _image_processor(

 @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
-class Phi3VForCausalLM(VisionLanguageModelBase):
+class Phi3VForCausalLM(nn.Module, SupportsVision):

     def __init__(self,
                  config: PretrainedConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()
+
         self.config = config
+        self.vlm_config = vlm_config
+
         self.model = LlamaModel(config, cache_config, quant_config)
         self.vision_embed_tokens = Phi3HDImageEmbedding(
-            vision_language_config, config, self.model.embed_tokens)
+            vlm_config, config, self.model.embed_tokens)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
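Composing nn.Module with the flag-carrying protocol is what lets Phi3VForCausalLM drop the dedicated VisionLanguageModelBase: the class now initializes as a plain nn.Module and stores its own config. A self-contained sketch of that composition (the *Stub names are hypothetical, for illustration only):

    from typing import ClassVar, Literal, Protocol

    import torch.nn as nn

    class SupportsVisionStub(Protocol):
        supports_vision: ClassVar[Literal[True]] = True

    class Phi3VStub(nn.Module, SupportsVisionStub):
        def __init__(self, vlm_config: dict) -> None:
            super().__init__()            # plain nn.Module init, no config argument
            self.vlm_config = vlm_config  # the model keeps its own config reference

    model = Phi3VStub({"image_input_type": "pixel_values"})
    assert model.supports_vision is True  # advertised via the protocol's default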
@@ -338,7 +342,7 @@ class Phi3VForCausalLM(VisionLanguageModelBase):
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)

-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType

         if expected_input_type != ImageInputType.PIXEL_VALUES:
@@ -266,8 +266,6 @@ class Qwen2Model(nn.Module):


 class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -269,8 +269,6 @@ class XverseModel(nn.Module):


 class XverseForCausalLM(nn.Module, SupportsLoRA):
-    supports_lora = True
-
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",