mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 17:47:06 +08:00
[Model] Composite weight loading for multimodal Qwen2 (#10944)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b26b4cd03c
commit
bf0e382e16
@ -2472,7 +2472,15 @@ class VllmConfig:
|
|||||||
return quant_config
|
return quant_config
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
|
def with_hf_config(
|
||||||
|
self,
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
architectures: Optional[list[str]] = None,
|
||||||
|
) -> "VllmConfig":
|
||||||
|
if architectures is not None:
|
||||||
|
hf_config = copy.deepcopy(hf_config)
|
||||||
|
hf_config.architectures = architectures
|
||||||
|
|
||||||
model_config = copy.deepcopy(self.model_config)
|
model_config = copy.deepcopy(self.model_config)
|
||||||
model_config.hf_config = hf_config
|
model_config.hf_config = hf_config
|
||||||
|
|
||||||
|
|||||||
@ -101,12 +101,10 @@ def _initialize_model(
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
architectures: Optional[list[str]] = None,
|
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
"""Initialize a model with the given configurations."""
|
"""Initialize a model with the given configurations."""
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
model_class, _ = get_model_architecture(model_config,
|
model_class, _ = get_model_architecture(model_config)
|
||||||
architectures=architectures)
|
|
||||||
|
|
||||||
signatures = inspect.signature(model_class.__init__)
|
signatures = inspect.signature(model_class.__init__)
|
||||||
all_params = [param.name for param in signatures.parameters.values()]
|
all_params = [param.name for param in signatures.parameters.values()]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""Utilities for selecting and loading models."""
|
"""Utilities for selecting and loading models."""
|
||||||
import contextlib
|
import contextlib
|
||||||
from typing import Optional, Tuple, Type
|
from typing import Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype):
|
|||||||
|
|
||||||
|
|
||||||
def get_model_architecture(
|
def get_model_architecture(
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||||
*,
|
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||||
architectures: Optional[list[str]] = None,
|
|
||||||
) -> Tuple[Type[nn.Module], str]:
|
|
||||||
if architectures is None:
|
|
||||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
|
||||||
|
|
||||||
# Special handling for quantized Mixtral.
|
# Special handling for quantized Mixtral.
|
||||||
# FIXME(woosuk): This is a temporary hack.
|
# FIXME(woosuk): This is a temporary hack.
|
||||||
|
|||||||
@ -444,14 +444,17 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
|||||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
if config.tie_word_embeddings:
|
if get_pp_group().is_last_rank:
|
||||||
self.lm_head = self.model.embed_tokens
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.model.embed_tokens
|
||||||
|
else:
|
||||||
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "lm_head"))
|
||||||
else:
|
else:
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = PPMissingLayer()
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=maybe_prefix(
|
|
||||||
prefix, "lm_head"))
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
self.sampler = get_sampler()
|
self.sampler = get_sampler()
|
||||||
|
|||||||
@ -19,7 +19,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||||
from functools import lru_cache
|
from functools import cached_property, lru_cache
|
||||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
@ -34,12 +34,7 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||||
InputContext, token_inputs)
|
InputContext, token_inputs)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
|
||||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import NestedTensors
|
from vllm.multimodal.inputs import NestedTensors
|
||||||
@ -47,15 +42,11 @@ from vllm.multimodal.utils import consecutive_placeholder_ranges
|
|||||||
from vllm.sequence import IntermediateTensors, SequenceData
|
from vllm.sequence import IntermediateTensors, SequenceData
|
||||||
|
|
||||||
from .interfaces import SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsMultiModal, SupportsPP
|
||||||
from .utils import merge_multimodal_embeddings
|
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
|
||||||
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_KEYS_TO_MODIFY_MAPPING = {
|
|
||||||
"language_model.lm_head": "lm_head",
|
|
||||||
"language_model.model": "language_model",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# # === Audio Inputs === #
|
# # === Audio Inputs === #
|
||||||
class Qwen2AudioInputs(TypedDict):
|
class Qwen2AudioInputs(TypedDict):
|
||||||
@ -281,25 +272,23 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.language_model = Qwen2Model(
|
self.language_model = init_vllm_registered_model(
|
||||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
vllm_config=vllm_config,
|
||||||
prefix=prefix)
|
hf_config=config.text_config,
|
||||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
prefix=maybe_prefix(prefix, "language_model"),
|
||||||
if config.text_config.tie_word_embeddings:
|
architectures=["Qwen2ForCausalLM"],
|
||||||
self.lm_head = self.language_model.embed_tokens
|
)
|
||||||
else:
|
|
||||||
self.lm_head = ParallelLMHead(config.text_config.vocab_size,
|
|
||||||
config.text_config.hidden_size,
|
|
||||||
quant_config=quant_config)
|
|
||||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
|
||||||
config.text_config.vocab_size,
|
|
||||||
logit_scale)
|
|
||||||
self.sampler = get_sampler()
|
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.language_model.make_empty_intermediate_tensors)
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sampler(self):
|
||||||
|
if hasattr(self.language_model, "sampler"):
|
||||||
|
return self.language_model.sampler
|
||||||
|
|
||||||
|
return get_sampler()
|
||||||
|
|
||||||
def _validate_and_reshape_mm_tensor(self,
|
def _validate_and_reshape_mm_tensor(self,
|
||||||
mm_input: Union[torch.Tensor,
|
mm_input: Union[torch.Tensor,
|
||||||
List[torch.Tensor]],
|
List[torch.Tensor]],
|
||||||
@ -414,72 +403,30 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
multimodal_embeddings)
|
multimodal_embeddings)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.language_model(input_ids,
|
hidden_states = self.language_model.model(input_ids,
|
||||||
positions,
|
positions,
|
||||||
kv_caches,
|
kv_caches,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds)
|
inputs_embeds=inputs_embeds)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(self, hidden_states: torch.Tensor,
|
def compute_logits(
|
||||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
self,
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata)
|
sampling_metadata: SamplingMetadata,
|
||||||
return logits
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
loader = AutoWeightsLoader(self)
|
||||||
# (param_name, shard_name, shard_id)
|
return loader.load_weights(weights)
|
||||||
("qkv_proj", "q_proj", "q"),
|
|
||||||
("qkv_proj", "k_proj", "k"),
|
|
||||||
("qkv_proj", "v_proj", "v"),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
if (self.config.text_config.tie_word_embeddings
|
|
||||||
and "lm_head.weight" in name):
|
|
||||||
continue
|
|
||||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
|
||||||
if key_to_modify in name:
|
|
||||||
name = name.replace(key_to_modify, new_key)
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
||||||
if weight_name not in name or 'audio' in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
# Remapping the name of FP8 kv-scale.
|
|
||||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
|
||||||
if name is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
|
||||||
default_weight_loader)
|
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||||
from functools import partial
|
from functools import cached_property, partial
|
||||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||||
Optional, Set, Tuple, Type, TypedDict, Union)
|
Optional, Set, Tuple, Type, TypedDict, Union)
|
||||||
|
|
||||||
@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
|
|||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_pp_group, parallel_state
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||||
InputContext, token_inputs)
|
InputContext, token_inputs)
|
||||||
@ -49,15 +49,12 @@ from vllm.model_executor import SamplingMetadata
|
|||||||
from vllm.model_executor.layers.activation import QuickGELU
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||||
GPTQMarlinConfig)
|
GPTQMarlinConfig)
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.image import cached_get_image_processor
|
from vllm.multimodal.image import cached_get_image_processor
|
||||||
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
|
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
|
||||||
@ -69,9 +66,8 @@ from vllm.transformers_utils.config import uses_mrope
|
|||||||
from vllm.transformers_utils.processor import cached_get_processor
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (PPMissingLayer, get_vit_attn_backend,
|
from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
|
||||||
is_pp_missing_parameter,
|
init_vllm_registered_model, maybe_prefix)
|
||||||
make_empty_intermediate_tensors_factory, maybe_prefix)
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -506,6 +502,8 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
mlp_ratio: float = vision_config.mlp_ratio
|
mlp_ratio: float = vision_config.mlp_ratio
|
||||||
|
|
||||||
self.spatial_merge_size = spatial_merge_size
|
self.spatial_merge_size = spatial_merge_size
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
self.patch_embed = Qwen2VisionPatchEmbed(
|
self.patch_embed = Qwen2VisionPatchEmbed(
|
||||||
patch_size=patch_size,
|
patch_size=patch_size,
|
||||||
@ -595,6 +593,53 @@ class Qwen2VisionTransformer(nn.Module):
|
|||||||
x = self.merger(x)
|
x = self.merger(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
]
|
||||||
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
|
loaded_params: Set[str] = set()
|
||||||
|
|
||||||
|
for name, loaded_weight in weights:
|
||||||
|
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||||
|
if weight_name not in name:
|
||||||
|
continue
|
||||||
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = param.weight_loader
|
||||||
|
weight_loader(param, loaded_weight, shard_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if name.endswith("qkv.weight"):
|
||||||
|
visual_num_heads = self.num_heads
|
||||||
|
visual_embed_dim = self.embed_dim
|
||||||
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
|
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
||||||
|
head_size,
|
||||||
|
visual_embed_dim)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
||||||
|
elif name.endswith("qkv.bias"):
|
||||||
|
visual_num_heads = self.num_heads
|
||||||
|
visual_embed_dim = self.embed_dim
|
||||||
|
head_size = visual_embed_dim // visual_num_heads
|
||||||
|
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
||||||
|
head_size)
|
||||||
|
loaded_weight = loaded_weight.transpose(0, 1)
|
||||||
|
loaded_weight = loaded_weight.reshape(-1)
|
||||||
|
|
||||||
|
param = params_dict[name]
|
||||||
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
default_weight_loader)
|
||||||
|
weight_loader(param, loaded_weight)
|
||||||
|
loaded_params.add(name)
|
||||||
|
return loaded_params
|
||||||
|
|
||||||
|
|
||||||
# === Vision input helpers === #
|
# === Vision input helpers === #
|
||||||
|
|
||||||
@ -1082,27 +1127,21 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
prefix=maybe_prefix(prefix, "visual"),
|
prefix=maybe_prefix(prefix, "visual"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
self.language_model = init_vllm_registered_model(
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "language_model"),
|
||||||
if get_pp_group().is_last_rank:
|
architectures=["Qwen2ForCausalLM"],
|
||||||
if config.tie_word_embeddings:
|
)
|
||||||
self.lm_head = self.model.embed_tokens
|
|
||||||
else:
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=maybe_prefix(
|
|
||||||
prefix, "lm_head"))
|
|
||||||
else:
|
|
||||||
self.lm_head = PPMissingLayer()
|
|
||||||
|
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
||||||
self.sampler = get_sampler()
|
|
||||||
|
|
||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
make_empty_intermediate_tensors_factory(
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
["hidden_states", "residual"], config.hidden_size))
|
|
||||||
|
@cached_property
|
||||||
|
def sampler(self):
|
||||||
|
if hasattr(self.language_model, "sampler"):
|
||||||
|
return self.language_model.sampler
|
||||||
|
|
||||||
|
return get_sampler()
|
||||||
|
|
||||||
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
||||||
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
|
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
|
||||||
@ -1261,7 +1300,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
|
multimodal_embeddings: Optional[List[Tuple[NestedTensors,
|
||||||
str]]] = None,
|
str]]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
if multimodal_embeddings is not None:
|
if multimodal_embeddings is not None:
|
||||||
for embeddings, modality in multimodal_embeddings:
|
for embeddings, modality in multimodal_embeddings:
|
||||||
if modality == "image":
|
if modality == "image":
|
||||||
@ -1330,7 +1369,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
multimodal_embeddings)
|
multimodal_embeddings)
|
||||||
input_ids = None
|
input_ids = None
|
||||||
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.language_model.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
kv_caches=kv_caches,
|
kv_caches=kv_caches,
|
||||||
@ -1340,80 +1379,28 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(self, hidden_states: torch.Tensor,
|
def compute_logits(
|
||||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
self,
|
||||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
hidden_states: torch.Tensor,
|
||||||
sampling_metadata)
|
sampling_metadata: SamplingMetadata,
|
||||||
return logits
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
logits: torch.Tensor,
|
logits: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> Optional[SamplerOutput]:
|
) -> Optional[SamplerOutput]:
|
||||||
next_tokens = self.sampler(logits, sampling_metadata)
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
return next_tokens
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str,
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
# (param_name, shard_name, shard_id)
|
orig_to_new_prefix={
|
||||||
("qkv_proj", "q_proj", "q"),
|
"lm_head.": "language_model.lm_head.",
|
||||||
("qkv_proj", "k_proj", "k"),
|
"model.": "language_model.model.",
|
||||||
("qkv_proj", "v_proj", "v"),
|
})
|
||||||
("gate_up_proj", "up_proj", 1),
|
|
||||||
("gate_up_proj", "gate_proj", 0),
|
|
||||||
]
|
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
|
||||||
loaded_params: Set[str] = set()
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "rotary_emb.inv_freq" in name:
|
|
||||||
continue
|
|
||||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
|
||||||
continue
|
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
|
||||||
if weight_name not in name:
|
|
||||||
continue
|
|
||||||
name = name.replace(weight_name, param_name)
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
weight_loader = param.weight_loader
|
|
||||||
weight_loader(param, loaded_weight, shard_id)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if "visual" in name and name.endswith("qkv.weight"):
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
|
||||||
head_size,
|
|
||||||
visual_embed_dim)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
|
|
||||||
elif "visual" in name and name.endswith("qkv.bias"):
|
|
||||||
visual_num_heads = self.config.vision_config.num_heads
|
|
||||||
visual_embed_dim = self.config.vision_config.embed_dim
|
|
||||||
head_size = visual_embed_dim // visual_num_heads
|
|
||||||
loaded_weight = loaded_weight.view(3, visual_num_heads,
|
|
||||||
head_size)
|
|
||||||
loaded_weight = loaded_weight.transpose(0, 1)
|
|
||||||
loaded_weight = loaded_weight.reshape(-1)
|
|
||||||
try:
|
|
||||||
# Skip loading extra bias for GPTQ models.
|
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
|
||||||
continue
|
|
||||||
if is_pp_missing_parameter(name, self):
|
|
||||||
continue
|
|
||||||
param = params_dict[name]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"Unexpected weight: {name}") from None
|
|
||||||
|
|
||||||
weight_loader = getattr(param, "weight_loader",
|
loader = AutoWeightsLoader(self)
|
||||||
default_weight_loader)
|
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||||
weight_loader(param, loaded_weight)
|
|
||||||
loaded_params.add(name)
|
|
||||||
return loaded_params
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
|
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
|
||||||
from vllm.platforms import _Backend, current_platform
|
from vllm.platforms import _Backend, current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_pin_memory_available
|
from vllm.utils import is_pin_memory_available, print_warning_once
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -251,12 +251,15 @@ def init_vllm_registered_model(
|
|||||||
"""
|
"""
|
||||||
from vllm.model_executor.model_loader.loader import _initialize_model
|
from vllm.model_executor.model_loader.loader import _initialize_model
|
||||||
|
|
||||||
if hf_config is not None:
|
if hf_config is None and architectures is not None:
|
||||||
vllm_config = vllm_config.with_hf_config(hf_config)
|
# So that the architectures field is overridden
|
||||||
|
hf_config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
return _initialize_model(vllm_config=vllm_config,
|
if hf_config is not None:
|
||||||
prefix=prefix,
|
vllm_config = vllm_config.with_hf_config(hf_config,
|
||||||
architectures=architectures)
|
architectures=architectures)
|
||||||
|
|
||||||
|
return _initialize_model(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@ -592,7 +595,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
|
|||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
selected_backend = _Backend.FLASH_ATTN
|
selected_backend = _Backend.FLASH_ATTN
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
print_warning_once(
|
||||||
"Current `vllm-flash-attn` has a bug inside vision module, "
|
"Current `vllm-flash-attn` has a bug inside vision module, "
|
||||||
"so we use xformers backend instead. You can run "
|
"so we use xformers backend instead. You can run "
|
||||||
"`pip install flash-attn` to use flash-attention backend.")
|
"`pip install flash-attn` to use flash-attention backend.")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user