[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)

This commit is contained in:
Cyrus Leung 2024-10-09 15:36:55 +08:00 committed by GitHub
parent 0b5b5d767e
commit 8bfaa4e31e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 241 additions and 361 deletions

View File

@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SequenceData
@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token
@ -687,35 +686,5 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision encoder,
    query tokens, Q-Former, projector, LLM backbone) in one pass.

    AutoWeightsLoader walks the module tree and dispatches each weight
    by its dotted-name prefix, so no per-component iteration is needed.
    """
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -42,8 +41,7 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
@ -349,16 +347,5 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for the vision embeddings and the LLM
    backbone via AutoWeightsLoader's prefix-based dispatch."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
@ -447,19 +447,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the model.

    When ``tie_word_embeddings`` is set, ``lm_head`` shares its weight
    with the input embeddings and has no separate checkpoint entry, so
    the ``lm_head.`` prefix is skipped in that case.
    """
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=(["lm_head."]
                       if self.config.tie_word_embeddings else None),
    )
    loader.load_weights(weights)

View File

@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -32,8 +31,8 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
IMG_START = '<img>'
IMG_END = '</img>'
@ -609,19 +608,5 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision encoder,
    MLP projector, LLM backbone) via AutoWeightsLoader."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -51,8 +51,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, group_weights_with_prefix,
is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@ -564,25 +563,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the model.

    Mistral-format checkpoint names are remapped on the fly via
    ``maybe_remap_mistral``; ``lm_head.`` is skipped when the word
    embeddings are tied (no separate lm_head weight exists then).
    """
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=(["lm_head."]
                       if self.config.tie_word_embeddings else None),
    )
    loader.load_weights(
        self.maybe_remap_mistral(name, loaded_weight)
        for name, loaded_weight in weights)
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
self.model.load_kv_cache_scales(quantization_param_path)

View File

@ -13,7 +13,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict):
@ -406,19 +405,5 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision tower,
    multi-modal projector, LLM backbone) via AutoWeightsLoader."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -15,7 +15,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
@ -29,8 +28,8 @@ from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@ -642,27 +641,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision tower,
    projector, image_newline parameter, LLM backbone) in one pass."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -15,7 +15,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
# For profile run
@ -458,19 +457,9 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for the video model's sub-modules.

    The ``image_newline`` checkpoint weight has no corresponding
    parameter here, so it is ignored rather than treated as an error.
    """
    loader = AutoWeightsLoader(
        self,
        # This model doesn't support images for now
        ignore_unexpected_prefixes=["image_newline"],
    )
    loader.load_weights(weights)

View File

@ -20,7 +20,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import (cached_get_tokenizer,
@ -35,8 +34,8 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings)
logger = init_logger(__name__)
@ -872,19 +871,5 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision tower,
    multi-modal projector, LLM backbone) via AutoWeightsLoader."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -11,7 +11,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -21,7 +20,7 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import group_weights_with_prefix, merge_multimodal_embeddings
from .utils import AutoWeightsLoader, merge_multimodal_embeddings
logger = init_logger(__name__)
@ -292,19 +291,5 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for all sub-modules (vision tower,
    multi-modal projector, LLM backbone) via AutoWeightsLoader."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -31,7 +31,6 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -42,15 +41,11 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
merge_multimodal_embeddings)
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"model.vision_embed_tokens": "vision_embed_tokens",
}
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044
@ -295,35 +290,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for the image-embedding sub-modules
    (img_processor, glb_GN / sub_GN parameters, img_projection)."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
@ -715,27 +683,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load HF-format checkpoint weights.

    HF names are first translated to vLLM module names (the language
    model lives under ``language_model.`` here), then dispatched by
    prefix via AutoWeightsLoader.
    """
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # Listed before "model." so the more specific prefix wins.
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    loader = AutoWeightsLoader(self)
    loader.load_weights(weights, mapper=hf_to_vllm_mapper)

View File

@ -48,8 +48,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, group_weights_with_prefix,
is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
@ -435,17 +434,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights into the model.

    ``lm_head.`` is skipped when the word embeddings are tied, since
    no separate lm_head weight exists in the checkpoint then.
    """
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=(["lm_head."]
                       if self.config.tie_word_embeddings else None),
    )
    loader.load_weights(weights)

View File

@ -16,13 +16,12 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .qwen2 import Qwen2Model
from .utils import group_weights_with_prefix
from .utils import AutoWeightsLoader
class ReLU(nn.Module):
@ -120,13 +119,5 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
return self._pooler(hidden_states, pooling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for the backbone model and the reward
    head (``score``) via AutoWeightsLoader's prefix dispatch."""
    loader = AutoWeightsLoader(self)
    loader.load_weights(weights)

View File

@ -25,11 +25,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors
@ -41,6 +36,8 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, merge_multimodal_embeddings)
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
@ -498,30 +495,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
return self.language_model.sample(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load checkpoint weights for the audio tower, projector and LLM
    backbone.

    Checkpoint names under ``audio_tower.model.encoder.`` are remapped
    to ``audio_tower.``; any remaining unmatched ``audio_tower.``
    weights (e.g. the unused decoder) are ignored instead of raising.
    """
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})

    loader = AutoWeightsLoader(self,
                               ignore_unexpected_prefixes=["audio_tower."])
    loader.load_weights(weights, mapper=hf_to_vllm_mapper)

View File

@ -1,7 +1,7 @@
import itertools
from collections import UserDict
from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
Tuple, Union, overload)
from dataclasses import dataclass, field
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Tuple, Union, overload)
import torch
import torch.nn as nn
@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
SchedulerConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""


@dataclass
class WeightsMapper:
    """Maps the name of each weight if they match the following patterns."""

    # Rules are applied in order: substring, then prefix, then suffix.
    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)

    def _map_name(self, key: str) -> Optional[str]:
        """Return the mapped name for ``key``, or ``None`` if a matching
        rule maps it to ``None`` (i.e. the weight should be dropped)."""
        for substr, new_key in self.orig_to_new_substr.items():
            if substr in key:
                if new_key is None:
                    return None

                key = key.replace(substr, new_key, 1)

        for prefix, new_key in self.orig_to_new_prefix.items():
            if key.startswith(prefix):
                if new_key is None:
                    return None

                key = key.replace(prefix, new_key, 1)

        for suffix, new_key in self.orig_to_new_suffix.items():
            if key.endswith(suffix):
                if new_key is None:
                    return None

                # Replace only the final occurrence of the suffix.
                key = new_key.join(key.rsplit(suffix, 1))

        return key

    def apply(
        self, weights: Iterable[Tuple[str, torch.Tensor]]
    ) -> Iterable[Tuple[str, torch.Tensor]]:
        """Lazily rename ``weights``, dropping entries mapped to None."""
        return ((out_name, data) for name, data in weights
                if (out_name := self._map_name(name)) is not None)
class AutoWeightsLoader:
    """
    Helper class to load weights into a :class:`torch.nn.Module`. It is able
    to automatically detect child modules and parameters while iterating over
    the weights only once.

    The weight loading logic for individual modules can be overridden
    by defining a ``load_weights`` method.

    Similarly, the weight loading logic for individual parameters can be
    overridden by defining a ``weight_loader`` method.
    """

    def __init__(
        self,
        module: nn.Module,
        *,
        skip_prefixes: Optional[List[str]] = None,
        ignore_unexpected_prefixes: Optional[List[str]] = None,
    ) -> None:
        super().__init__()

        # Root module whose tree the weights are loaded into.
        self.module = module
        # Weight names starting with these prefixes are silently skipped.
        self.skip_prefixes = skip_prefixes or []
        # Weight names with no matching module/param are only tolerated
        # when they start with one of these prefixes.
        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []

    def _groupby_prefix(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
        """Group weights by their first dotted-name component, yielding
        ``(prefix, sub-weights)`` with the prefix stripped off each name.

        NOTE: relies on the input being ordered so that equal prefixes are
        adjacent (itertools.groupby only groups consecutive keys).
        """
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                            for weight_name, weight_data in weights)

        for prefix, group in itertools.groupby(weights_by_parts,
                                               key=lambda x: x[0][0]):
            yield (
                prefix,
                # Because maxsplit=1 in weight_name.split(...),
                # the length of `parts` must either be 1 or 2
                (("" if len(parts) == 1 else parts[1], weights_data)
                 for parts, weights_data in group),
            )

    def _get_qualname(self, prefix: str, rest: str) -> str:
        """Join two dotted-name fragments, tolerating empty parts."""
        if prefix == "":
            return rest
        if rest == "":
            return prefix

        return ".".join((prefix, rest))

    def _can_skip(self, qualname: str) -> bool:
        return any(qualname.startswith(p) for p in self.skip_prefixes)

    def _can_ignore_unexpected(self, qualname: str) -> bool:
        return any(
            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)

    def _load_param(
        self,
        base_prefix: str,
        param: nn.Parameter,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> None:
        """Load the weight(s) destined for a single parameter.

        A non-empty residual name means the checkpoint nests deeper than
        this parameter, which is an error unless explicitly ignored.
        """
        for weight_name, weight_data in weights:
            weight_qualname = self._get_qualname(base_prefix, weight_name)

            if self._can_skip(weight_qualname):
                continue

            if weight_name != "":
                if not self._can_ignore_unexpected(weight_qualname):
                    raise ValueError(
                        f"Attempted to load nested weight '{weight_qualname}' "
                        f"into a single parameter '{base_prefix}'")

                continue

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight_data)

    def _load_module(
        self,
        base_prefix: str,
        module: nn.Module,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> None:
        """Recursively dispatch weights to ``module``'s children/params."""
        if isinstance(module, PPMissingLayer):
            return

        # Avoid infinite recursion since this function is typically
        # called inside load_weights of the module itself
        if module != self.module:
            module_load_weights = getattr(module, "load_weights", None)
            if callable(module_load_weights):
                module_load_weights(weights)
                return

        child_modules = dict(module.named_children())
        child_params = dict(module.named_parameters(recurse=False))

        for child_prefix, child_weights in self._groupby_prefix(weights):
            prefix = self._get_qualname(base_prefix, child_prefix)

            if self._can_skip(prefix):
                continue

            if child_prefix in child_modules:
                self._load_module(prefix, child_modules[child_prefix],
                                  child_weights)
            elif child_prefix in child_params:
                self._load_param(prefix, child_params[child_prefix],
                                 child_weights)
            else:
                if not self._can_ignore_unexpected(prefix):
                    msg = f"There is no module or parameter named '{prefix}'"
                    raise ValueError(msg)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        *,
        mapper: Optional["WeightsMapper"] = None,
    ) -> None:
        """Load ``weights`` into the root module, optionally renaming
        names through ``mapper`` first."""
        if mapper is not None:
            weights = mapper.apply(weights)

        self._load_module("", self.module, weights)
def init_vllm_registered_model(