[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)

Cyrus Leung, 2024-10-09 15:36:55 +08:00 (committed by GitHub)
parent 0b5b5d767e
commit 8bfaa4e31e
15 changed files with 241 additions and 361 deletions

vllm/model_executor/models/blip2.py

@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SequenceData
@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
 from .blip import (BlipVisionModel, dummy_image_for_blip,
                    get_max_blip_image_tokens)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 # We use this internally as placeholders since there is no image token
@@ -687,35 +686,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load query tokens
-        for name, loaded_weight in weights_group["query_tokens"]:
-            assert name == ""
-            param = self.query_tokens
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load qformer
-        qformer_params_dict = dict(self.qformer.named_parameters())
-        for name, loaded_weight in weights_group["qformer"]:
-            param = qformer_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.language_projection.named_parameters())
-        for name, loaded_weight in weights_group["language_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

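The pattern above repeats throughout this commit: the hand-written per-component loops are replaced by a single AutoWeightsLoader call, which walks the module tree and dispatches each checkpoint entry to the matching child module or bare parameter. A minimal sketch of how the loader is driven, using a hypothetical toy module rather than a real vLLM model:

from typing import Iterable, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class ToyComposite(nn.Module):
    """Hypothetical stand-in for a multimodal model with several parts."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_model = nn.Linear(4, 4)               # vision encoder
        self.language_model = nn.Linear(4, 4)             # LLM backbone
        self.query_tokens = nn.Parameter(torch.zeros(4))  # bare parameter

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Entries named "vision_model.*" and "language_model.*" are routed to
        # the corresponding children; a bare "query_tokens" entry is loaded
        # directly into the parameter via its weight_loader, falling back to
        # default_weight_loader.
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights)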
vllm/model_executor/models/fuyu.py

@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.persimmon import PersimmonForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -42,8 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)

 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    merge_multimodal_embeddings)
+from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -349,16 +347,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision embeddings
-        vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
-        for name, loaded_weight in weights_group["vision_embed_tokens"]:
-            param = vision_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/gemma2.py

@@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

 logger = init_logger(__name__)
@@ -447,19 +447,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            # NOTE: For now self.lm_head is not defined because
-            # tie_word_embeddings is assumed to the False
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)

vllm/model_executor/models/internvl.py

@@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.intern_vit import InternVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -32,8 +31,8 @@ from vllm.utils import is_list_of
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -609,19 +608,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_model.load_weights(weights_group["vision_model"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in weights_group["mlp1"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llama.py

@@ -51,8 +51,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_hip

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -564,25 +563,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights = [
-            self.maybe_remap_mistral(name, loaded_weight)
-            for name, loaded_weight in weights
-        ]
-
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(
+            self.maybe_remap_mistral(name, loaded_weight)
+            for name, loaded_weight in weights)

     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)

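skip_prefixes drops checkpoint entries under a given prefix before they are dispatched. For the tied-embedding case above this means that an lm_head.* entry, if a checkpoint happens to contain one, is not allowed to overwrite the weight already loaded through the embedding. A self-contained sketch with a hypothetical toy module:

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class ToyBackbone(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Linear(2, 2, bias=False)


class TiedToy(nn.Module):
    """Toy model whose lm_head is tied to the embedding inside the backbone."""

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyBackbone()
        self.lm_head = self.model.embed


toy = TiedToy()
weights = [
    ("model.embed.weight", torch.ones(2, 2)),
    ("lm_head.weight", torch.zeros(2, 2)),  # hypothetical duplicate entry
]
AutoWeightsLoader(toy, skip_prefixes=["lm_head."]).load_weights(weights)

# "model.embed.weight" was loaded; the "lm_head.*" entry was skipped, so the
# tied parameter keeps the values loaded under "model.embed.weight".
assert torch.all(toy.lm_head.weight == 1)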
vllm/model_executor/models/llava.py

@@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)


 class LlavaImagePixelInputs(TypedDict):
@@ -406,19 +405,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llava_next.py

@@ -15,7 +15,6 @@ from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -29,8 +28,8 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@@ -642,27 +641,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load newline
-        for name, loaded_weight in weights_group["image_newline"]:
-            assert name == ""
-            param = self.image_newline
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/llava_next_video.py

@@ -15,7 +15,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (group_weights_with_prefix, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 # For profile run
@@ -458,19 +457,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(
+            self,
+            # This model doesn't support images for now
+            ignore_unexpected_prefixes=["image_newline"],
+        )
+        loader.load_weights(weights)

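ignore_unexpected_prefixes is the complementary knob: rather than skipping known-but-unwanted targets, it silences checkpoint entries that have no matching module or parameter at all, which would otherwise raise "There is no module or parameter named ...". LLaVA-NeXT-Video defines no image_newline, so that entry from image-capable checkpoints is simply dropped. A small sketch with a hypothetical toy module:

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import AutoWeightsLoader


class VideoToy(nn.Module):
    """Toy model that, like the model above, defines no image_newline."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(2, 2, bias=False)


toy = VideoToy()
weights = [
    ("vision_tower.weight", torch.ones(2, 2)),
    ("image_newline", torch.zeros(2)),  # present only in the checkpoint
]

# Without ignore_unexpected_prefixes this would raise:
#   ValueError: There is no module or parameter named 'image_newline'
loader = AutoWeightsLoader(toy, ignore_unexpected_prefixes=["image_newline"])
loader.load_weights(weights)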
vllm/model_executor/models/llava_onevision.py

@@ -20,7 +20,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import (cached_get_tokenizer,
@@ -35,8 +34,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (flatten_bn, group_weights_with_prefix,
-                    init_vllm_registered_model, merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
+                    merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -872,19 +871,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/paligemma.py

@@ -11,7 +11,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.gemma import GemmaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsMultiModal, SupportsPP
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import group_weights_with_prefix, merge_multimodal_embeddings
+from .utils import AutoWeightsLoader, merge_multimodal_embeddings

 logger = init_logger(__name__)
@@ -292,19 +291,5 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision tower
-        self.vision_tower.load_weights(weights_group["vision_tower"])
-
-        # load mlp projector
-        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in weights_group["multi_modal_projector"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/phi3v.py

@@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -42,15 +41,11 @@ from vllm.utils import is_list_of
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (flatten_bn, group_weights_with_prefix,
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)

-_KEYS_TO_MODIFY_MAPPING = {
-    "model.vision_embed_tokens": "vision_embed_tokens",
-}
-
 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 32044
@@ -295,35 +290,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load vision encoder
-        self.img_processor.load_weights(weights_group["img_processor"])
-
-        # load glb_GN
-        for name, loaded_weight in weights_group["glb_GN"]:
-            assert name == ""
-            param = self.glb_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load sub_GN
-        for name, loaded_weight in weights_group["sub_GN"]:
-            assert name == ""
-            param = self.sub_GN
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load mlp projector
-        mlp_params_dict = dict(self.img_projection.named_parameters())
-        for name, loaded_weight in weights_group["img_projection"]:
-            param = mlp_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)


 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
@@ -715,27 +683,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        hf_to_vllm_mapping = {
-            "model.vision_embed_tokens.": "vision_embed_tokens.",
-            "lm_head.": "language_model.lm_head.",
-            "model.": "language_model.model.",
-        }
-
-        def hf_to_vllm_name(key: str) -> str:
-            for hf_name, vllm_name in hf_to_vllm_mapping.items():
-                if key.startswith(hf_name):
-                    return key.replace(hf_name, vllm_name, 1)
-
-            return key
-
-        vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
-
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(vllm_weights.items())
-
-        # load vision embeddings and encoder
-        self.vision_embed_tokens.load_weights(
-            weights_group["vision_embed_tokens"])
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={
+                "model.vision_embed_tokens.": "vision_embed_tokens.",
+                "lm_head.": "language_model.lm_head.",
+                "model.": "language_model.model.",
+            })
+
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)

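WeightsMapper replaces the ad-hoc hf_to_vllm_name closure: it renames checkpoint keys by substring, prefix, or suffix, checking each pattern in dict order, and a key that maps to None is dropped entirely. A worked example of the mapper defined above, applied to hypothetical Phi-3-V checkpoint keys:

import torch

from vllm.model_executor.models.utils import WeightsMapper

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "model.vision_embed_tokens.": "vision_embed_tokens.",
        "lm_head.": "language_model.lm_head.",
        "model.": "language_model.model.",
    })

weights = [  # hypothetical checkpoint keys
    ("model.vision_embed_tokens.img_projection.0.weight", torch.empty(1)),
    ("model.layers.0.mlp.gate_up_proj.weight", torch.empty(1)),
    ("lm_head.weight", torch.empty(1)),
]

for name, _ in hf_to_vllm_mapper.apply(weights):
    print(name)
# vision_embed_tokens.img_projection.0.weight
# language_model.model.layers.0.mlp.gate_up_proj.weight
# language_model.lm_head.weight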
vllm/model_executor/models/qwen2.py

@@ -48,8 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, group_weights_with_prefix,
-                    is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers)

@@ -435,17 +434,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        if not self.config.tie_word_embeddings:
-            lm_head_dict = dict(self.lm_head.named_parameters())
-            for name, loaded_weight in weights_group["lm_head"]:
-                if is_pp_missing_parameter(name, self.lm_head):
-                    continue
-
-                param = lm_head_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        loader.load_weights(weights)

vllm/model_executor/models/qwen2_rm.py

@@ -16,13 +16,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput

 from .interfaces import SupportsPP
 from .qwen2 import Qwen2Model
-from .utils import group_weights_with_prefix
+from .utils import AutoWeightsLoader


 class ReLU(nn.Module):
@@ -120,13 +119,5 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
         return self._pooler(hidden_states, pooling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        weights_group = group_weights_with_prefix(weights)
-
-        self.model.load_weights(weights_group["model"])
-
-        score_dict = dict(self.score.named_parameters())
-        for name, loaded_weight in weights_group["score"]:
-            param = score_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+        loader = AutoWeightsLoader(self)
+        loader.load_weights(weights)

vllm/model_executor/models/ultravox.py

@@ -25,11 +25,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import (flatten_bn,
-                                              group_weights_with_prefix,
-                                              init_vllm_registered_model,
-                                              merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs, NestedTensors
@@ -41,6 +36,8 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, merge_multimodal_embeddings)

 _AUDIO_PLACEHOLDER_TOKEN = 128002
 _AUDIO_TOKENS_PER_SECOND = 6.25
@@ -498,30 +495,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators for components
-        weights_group = group_weights_with_prefix(weights)
-
-        # load audio tower weights
-        audio_tower_weights = weights_group["audio_tower"]
-        audio_tower_params_dict = dict(
-            self.audio_tower.named_parameters(
-                prefix=self.audio_tower.base_model_prefix))
-        for name, loaded_weight in audio_tower_weights:
-            if name in audio_tower_params_dict:
-                param = audio_tower_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-        # load projector weights
-        projector_weights = weights_group["multi_modal_projector"]
-        projector_params_dict = dict(
-            self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in projector_weights:
-            param = projector_params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-
-        # load llm backbone
-        self.language_model.load_weights(weights_group["language_model"])
+        hf_to_vllm_mapper = WeightsMapper(
+            orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
+
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["audio_tower."])
+        loader.load_weights(weights, mapper=hf_to_vllm_mapper)

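For Ultravox the two mechanisms combine: the mapper strips the audio_tower.model.encoder. wrapper so the Whisper encoder weights land directly on audio_tower, while ignore_unexpected_prefixes=["audio_tower."] lets any remaining audio_tower.* entries that do not correspond to the instantiated encoder (presumably decoder or other wrapper weights in the checkpoint) be dropped instead of raising. A sketch of the key flow, with hypothetical checkpoint keys:

import torch

from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(
    orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

weights = [  # hypothetical checkpoint keys
    ("audio_tower.model.encoder.layers.0.fc1.weight", torch.empty(1)),
    ("audio_tower.model.decoder.embed_tokens.weight", torch.empty(1)),
    ("language_model.model.embed_tokens.weight", torch.empty(1)),
]

for name, _ in mapper.apply(weights):
    print(name)
# audio_tower.layers.0.fc1.weight                (loaded into the encoder)
# audio_tower.model.decoder.embed_tokens.weight  (no matching module; dropped
#                                                 via ignore_unexpected_prefixes)
# language_model.model.embed_tokens.weight       (loaded into the LLM backbone)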
vllm/model_executor/models/utils.py

@@ -1,7 +1,7 @@
 import itertools
-from collections import UserDict
-from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
-                    Tuple, Union, overload)
+from dataclasses import dataclass, field
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Tuple, Union, overload)

 import torch
 import torch.nn as nn
@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available

-class WeightsGroup(UserDict):
-    """
-    Wraps grouped weights dictionary for a more informative error message
-    when attempting to access a weight component that does not exist.
-    """
-
-    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
-        try:
-            return super().__getitem__(key)
-        except KeyError as exc:
-            msg = (f"There is no weights named with the prefix: {key}. "
-                   f"Available prefix: {set(self.keys())}")
-            raise KeyError(msg) from exc
-
-
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
-    """
-    Helper function to load weights for inner vLLM models.
-
-    See also:
-        :ref:`init_vllm_registered_model`
-    """
-    for name, loaded_weight in weights:
-        name = name.split(".")
-        if prefix == name.pop(0):
-            name = ".".join(name)
-            yield name, loaded_weight
-
-
-def group_weights_with_prefix(
-        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
-    """
-    Helper function to group weights with prefix
-    """
-    init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
-    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
-
-    return WeightsGroup({
-        prefix: filter_weights(component, prefix)
-        for component, prefix in zip(repeated_weights, weights_prefix)
-    })
+WeightsMapping = Mapping[str, Optional[str]]
+"""If a key maps to a value of `None`, the corresponding weight is ignored."""
+
+
+@dataclass
+class WeightsMapper:
+    """Maps the name of each weight if they match the following patterns."""
+
+    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
+    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
+    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
+
+    def _map_name(self, key: str) -> Optional[str]:
+        for substr, new_key in self.orig_to_new_substr.items():
+            if substr in key:
+                if new_key is None:
+                    return None
+
+                key = key.replace(substr, new_key, 1)
+
+        for prefix, new_key in self.orig_to_new_prefix.items():
+            if key.startswith(prefix):
+                if new_key is None:
+                    return None
+
+                key = key.replace(prefix, new_key, 1)
+
+        for suffix, new_key in self.orig_to_new_suffix.items():
+            if key.endswith(suffix):
+                if new_key is None:
+                    return None
+
+                key = new_key.join(key.rsplit(suffix, 1))
+
+        return key
+
+    def apply(
+        self, weights: Iterable[Tuple[str, torch.Tensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        return ((out_name, data) for name, data in weights
+                if (out_name := self._map_name(name)) is not None)
+
+
+class AutoWeightsLoader:
+    """
+    Helper class to load weights into a :class:`torch.nn.Module`. It is able
+    to automatically detect child modules and parameters while iterating over
+    the weights only once.
+
+    The weight loading logic for individual modules can be overridden
+    by defining a ``load_weights`` method.
+
+    Similarly, the weight loading logic for individual parameters can be
+    overridden by defining a ``weight_loader`` method.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[List[str]] = None,
+        ignore_unexpected_prefixes: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        weights_by_parts = ((weight_name.split(".", 1), weight_data)
+                            for weight_name, weight_data in weights)
+
+        for prefix, group in itertools.groupby(weights_by_parts,
+                                               key=lambda x: x[0][0]):
+            yield (
+                prefix,
+                # Because maxsplit=1 in weight_name.split(...),
+                # the length of `parts` must either be 1 or 2
+                (("" if len(parts) == 1 else parts[1], weights_data)
+                 for parts, weights_data in group),
+            )
+
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
+
+        return ".".join((prefix, rest))
+
+    def _can_skip(self, qualname: str) -> bool:
+        return any(qualname.startswith(p) for p in self.skip_prefixes)
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        return any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                continue
+
+            if weight_name != "":
+                if not self._can_ignore_unexpected(weight_qualname):
+                    raise ValueError(
+                        f"Attempted to load nested weight '{weight_qualname}' "
+                        f"into a single parameter '{base_prefix}'")
+
+                continue
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, weight_data)
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        if isinstance(module, PPMissingLayer):
+            return
+
+        # Avoid infinite recursion since this function is typically
+        # called inside load_weights of the module itself
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                module_load_weights(weights)
+                return
+
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if self._can_skip(prefix):
+                continue
+
+            if child_prefix in child_modules:
+                self._load_module(prefix, child_modules[child_prefix],
+                                  child_weights)
+            elif child_prefix in child_params:
+                self._load_param(prefix, child_params[child_prefix],
+                                 child_weights)
+            else:
+                if not self._can_ignore_unexpected(prefix):
+                    msg = f"There is no module or parameter named '{prefix}'"
+                    raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        *,
+        mapper: Optional[WeightsMapper] = None,
+    ) -> None:
+        if mapper is not None:
+            weights = mapper.apply(weights)
+
+        self._load_module("", self.module, weights)
+
+
 def init_vllm_registered_model(