[Model] Clean up MiniCPMV (#10751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-11-29 12:47:06 +08:00 committed by GitHub
parent c83919c7a6
commit fa6ecb9aa7
7 changed files with 149 additions and 215 deletions

View File

@ -295,16 +295,29 @@ VLM_TEST_SETTINGS = {
)
],
),
"minicpmv": VLMTestInfo(
"minicpmv_25": VLMTestInfo(
models=["openbmb/MiniCPM-Llama3-V-2_5"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id],
postprocess_inputs=model_utils.wrap_inputs_post_processor,
hf_output_post_proc=model_utils.minicmpv_trunc_hf_output,
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
"minicpmv_26": VLMTestInfo(
models=["openbmb/MiniCPM-V-2_6"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
img_idx_to_prompt=lambda idx: "(<image>./</image>)\n",
max_model_len=4096,
max_num_seqs=2,
get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501
postprocess_inputs=model_utils.ignore_inputs_post_processor(
"image_sizes"
),
hf_output_post_proc=model_utils.minicpmv_trunc_hf_output,
),
# Tests for phi3v currently live in another file because of a bug in
# transformers. Once this issue is fixed, we can enable them here instead.
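The hunk above splits the old combined "minicpmv" entry into separate settings for MiniCPM-Llama3-V-2_5 (single-image only, stop tokens taken from tok.eos_id/tok.eot_id) and MiniCPM-V-2_6 (single- and multi-image, stop tokens looked up via convert_tokens_to_ids). Below is a standalone, hedged sketch of how the prompt pieces compose for a single-image case; the two callables are copied from the settings above, while the question text and the composition order are assumptions for illustration.

# Standalone sketch, not part of this commit. prompt_formatter and
# img_idx_to_prompt are copied from the "minicpmv_26" settings above;
# the question text is invented.
prompt_formatter = lambda img_prompt: (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    f"{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n")
img_idx_to_prompt = lambda idx: "(<image>./</image>)\n"

question = "What is in this image?"
print(prompt_formatter(img_idx_to_prompt(0) + question))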

View File

@ -170,7 +170,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput,
####### Post-processors for HF outputs
def minicmpv_trunc_hf_output(hf_output: RunnerOutput,
def minicpmv_trunc_hf_output(hf_output: RunnerOutput,
model: str) -> RunnerOutput:
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
@ -197,6 +197,17 @@ def get_key_type_post_processor(
return process
def ignore_inputs_post_processor(
hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]:
"""Gets a handle to a post processor which ignores a given key."""
def process(hf_inputs: BatchEncoding, dtype: str):
del hf_inputs[hf_inp_key]
return hf_inputs
return process
def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str):
return {"model_inputs": hf_inputs}

View File

@ -242,7 +242,7 @@ class FusedMoE(torch.nn.Module):
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
@ -261,7 +261,7 @@ class FusedMoE(torch.nn.Module):
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
loaded_weight: torch.Tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
@ -274,7 +274,7 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
@ -292,7 +292,7 @@ class FusedMoE(torch.nn.Module):
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
@ -311,7 +311,7 @@ class FusedMoE(torch.nn.Module):
param_data[expert_id] = loaded_weight
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.tensor, tp_rank: int):
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
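The FusedMoE hunks above change only type annotations: torch.tensor is the tensor factory function, while torch.Tensor is the class tensors are instances of, so only the latter is meaningful as an annotation. A minimal illustration; the scale helper is invented.

import torch

x = torch.tensor([1.0, 2.0])        # torch.tensor: factory function
assert isinstance(x, torch.Tensor)  # torch.Tensor: the class of its result

def scale(loaded_weight: torch.Tensor, factor: float) -> torch.Tensor:
    # Annotation style matching the corrected signatures above.
    return loaded_weight * factor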

View File

@ -52,7 +52,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -378,6 +378,7 @@ class MiniCPMModel(nn.Module):
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
@ -437,6 +438,73 @@ class MiniCPMModel(nn.Module):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
@ -480,8 +548,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.cache_config = cache_config
self.quant_config = quant_config
self.num_experts = getattr(self.config, "num_experts", 0)
self._init_model(vllm_config=vllm_config, prefix=prefix)
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
@ -506,8 +575,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.model.make_empty_intermediate_tensors)
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPMModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
return MiniCPMModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
@ -546,72 +614,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.num_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
# With tie_word_embeddings, we can skip lm_head.weight
# The weight might appear unnecessarily in the files if the model is
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
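With the stacked-parameter and expert mappings moved into MiniCPMModel.load_weights above, MiniCPMForCausalLM.load_weights now just delegates to AutoWeightsLoader, skipping lm_head.* entries when word embeddings are tied. A hedged, framework-free sketch of the skip_prefixes idea follows; ToyWeightsLoader is invented, whereas the real AutoWeightsLoader recurses into sub-modules and forwards names to their own load_weights.

from typing import Iterable, List, Optional, Set, Tuple

import torch

class ToyWeightsLoader:
    """Invented stand-in for AutoWeightsLoader; records names instead of loading."""

    def __init__(self, module, skip_prefixes: Optional[List[str]] = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        loaded: Set[str] = set()
        for name, tensor in weights:
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue  # e.g. drop "lm_head.*" when embeddings are tied
            loaded.add(name)  # the real loader dispatches to sub-module loaders
        return loaded

weights = [("model.embed_tokens.weight", torch.zeros(1)),
           ("lm_head.weight", torch.zeros(1))]
print(ToyWeightsLoader(None, skip_prefixes=["lm_head."]).load_weights(weights))
# {'model.embed_tokens.weight'}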

View File

@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
from .utils import make_layers, maybe_prefix
from .utils import make_layers
class MiniCPM3Attention(nn.Module):
@ -248,5 +248,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
}
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPM3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)

View File

@ -22,7 +22,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import math
import re
from functools import partial
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
Set, Tuple, TypedDict, Union)
@ -37,19 +37,15 @@ from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
get_2d_sincos_pos_embed)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.minicpm import MiniCPMForCausalLM
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
@ -58,11 +54,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter, maybe_prefix
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
}
from .utils import AutoWeightsLoader, maybe_prefix
RawImageType = Union[Image.Image, torch.Tensor]
@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
def get_placeholder(image_size: Tuple[int, int], num_image: int):
if version == (2, 0) or version == (2, 5):
return image_processor. \
get_slice_image_placeholder(image_size)
return image_processor. \
get_slice_image_placeholder(image_size, num_image)
return image_processor.get_slice_image_placeholder(image_size)
return image_processor.get_slice_image_placeholder(
image_size, num_image)
prompt = inputs.get("prompt")
token_ids = inputs.get("prompt_token_ids")
@ -400,37 +391,32 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.vpm = self.init_vision_module(config,
quant_config,
prefix=maybe_prefix(prefix, "vpm"))
param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
self.vpm.embeddings.embed_dim)
self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "resampler"))
self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "llm.lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.llm, "sampler"):
return self.llm.sampler
return get_sampler()
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids)
if hasattr(self.config, "scale_emb"):
vlm_embedding *= self.config.scale_emb
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
@ -575,7 +561,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# for `torch.compile` integration
input_ids = None
output = self.llm(
output = self.llm.model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
@ -590,9 +576,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
return self.llm.compute_logits(hidden_states, sampling_metadata)
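Because self.llm is now a full *ForCausalLM module rather than a bare decoder behind LLMWrapper, the multimodal model runs the decoder via self.llm.model(...) and forwards logit computation to self.llm.compute_logits(...). A hedged toy sketch of this delegation pattern follows; ToyCausalLM and ToyVLM are invented and far simpler than the real classes (for one, the real compute_logits also takes sampling metadata).

import torch
import torch.nn as nn

class ToyCausalLM(nn.Module):
    """Invented stand-in for LlamaForCausalLM / Qwen2ForCausalLM / MiniCPMForCausalLM."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.model = nn.Linear(hidden, hidden)   # stands in for the decoder stack
        self.lm_head = nn.Linear(hidden, vocab)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)

class ToyVLM(nn.Module):
    """Invented stand-in for MiniCPMVBaseModel."""

    def __init__(self):
        super().__init__()
        self.llm = ToyCausalLM()

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        # Mirrors "output = self.llm.model(...)" above: run only the decoder.
        return self.llm.model(inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Mirrors "return self.llm.compute_logits(...)" above.
        return self.llm.compute_logits(hidden_states)

vlm = ToyVLM()
print(vlm.compute_logits(vlm(torch.randn(2, 8))).shape)  # torch.Size([2, 16])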
def sample(
self,
@ -604,52 +588,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
use_default_weight_loading = False
if self.is_default_weight_loading(name):
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
def get_mm_mapping(self) -> MultiModelKeys:
"""
@ -693,9 +633,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
data: MiniCPMVImageInputs) -> torch.Tensor:
raise NotImplementedError
def is_default_weight_loading(self, name: str) -> bool:
raise NotImplementedError
class MiniCPMV2_0(MiniCPMVBaseModel):
@ -708,8 +645,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
name="model")
return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@ -717,11 +653,12 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module:
# TODO :refactor this vision model
# TODO: refactor this vision model
try:
import timm
except ImportError:
raise ImportError("Please install timm==0.9.10") from ImportError
with set_default_torch_dtype(torch.float16):
model = timm.create_model(
"vit_so400m_patch14_siglip_384.webli",
@ -731,6 +668,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
dynamic_img_pad=True,
)
model = model.to(dtype=torch.get_default_dtype())
if (isinstance(model, timm.models.VisionTransformer)
and model.attn_pool is not None):
model.attn_pool = torch.nn.Identity()
@ -759,7 +698,7 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@ -790,9 +729,6 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return self.get_vision_embedding(pixel_values)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@ -843,8 +779,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
name="model")
return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@ -871,7 +806,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@ -913,9 +849,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
return self.get_vision_embedding(all_pixel_values.type(dtype),
patch_attn_mask, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name
class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
@ -966,8 +899,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
name="model")
return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix)
def init_vision_module(
self,
@ -995,7 +927,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
kv_dim=vision_dim,
quant_config=quant_config,
prefix=prefix)
return resampler
return resampler.to(device="cuda", dtype=torch.get_default_dtype())
def get_vision_embedding(
self,
@ -1043,9 +976,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
return self.resampler(vision_embedding, tgt_sizes)
def is_default_weight_loading(self, name: str) -> bool:
return "resampler" in name
_SUPPORT_VERSION = {
(2, 0): MiniCPMV2_0,

View File

@ -1,7 +1,7 @@
import itertools
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Protocol, Set, Tuple, Union, overload)
from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, Union, overload)
import torch
import torch.nn as nn
@ -560,30 +560,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
return make_empty_intermediate_tensors
class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
def __init__(self, llm: nn.Module, name: str) -> None:
super().__init__()
self.model_name = name
setattr(self, name, llm)
def __getattr__(self, key: str):
llm = super().__getattr__(self.model_name)
if key == self.model_name:
return llm
return getattr(llm, key)
# We need to explicitly override this
def __call__(self, *args: Any, **kwargs: Any) -> Any:
llm = super().__getattr__(self.model_name)
return llm(*args, **kwargs)
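The removed LLMWrapper existed only to expose the bare language model under a named attribute so that parameter names lined up with LoRA checkpoints trained with PEFT. The MiniCPM-V classes now build full *ForCausalLM modules, which already nest their decoder under a model attribute, so the indirection is redundant. A hedged toy illustration of the resulting parameter naming follows; all three Toy* classes are invented.

import torch.nn as nn

class ToyModel(nn.Module):          # stands in for LlamaModel / Qwen2Model
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(4, 4)

class ToyForCausalLM(nn.Module):    # stands in for LlamaForCausalLM, etc.
    def __init__(self):
        super().__init__()
        self.model = ToyModel()
        self.lm_head = nn.Linear(4, 4)

class ToyMiniCPMV(nn.Module):       # stands in for the multimodal wrapper
    def __init__(self):
        super().__init__()
        self.llm = ToyForCausalLM()

print([name for name, _ in ToyMiniCPMV().named_parameters()])
# ['llm.model.layers.weight', 'llm.model.layers.bias',
#  'llm.lm_head.weight', 'llm.lm_head.bias']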
def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
"""
Get the available attention backend for Vision Transformer.