mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 23:00:56 +08:00
[Misc] Add uninitialized params tracking for AutoWeightsLoader (#10327)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
d1557e66d3
commit
c4e464333e
@ -334,7 +334,17 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
with target_device:
|
||||
model = _initialize_model(vllm_config=vllm_config)
|
||||
|
||||
model.load_weights(self._get_all_weights(model_config, model))
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self._get_all_weights(model_config, model))
|
||||
# We only enable strict check for non-quantiized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
|
||||
for _, module in model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""Inference-only Snowflake Arctic model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -480,7 +480,8 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -518,6 +519,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
|
||||
("ws", f"experts.{expert_id}.w3.weight", expert_id))
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
logger.info(
|
||||
"It will take ~10 minutes loading from the 16-bit weights. "
|
||||
@ -573,3 +575,5 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -404,13 +404,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -449,6 +451,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -337,7 +337,8 @@ class BertModel(nn.Module):
|
||||
|
||||
return self.encoder(hidden_states, kv_caches, attn_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "query", "q"),
|
||||
@ -346,6 +347,7 @@ class BertModel(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "pooler" in name:
|
||||
continue
|
||||
@ -368,6 +370,8 @@ class BertModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class BertEmbeddingModel(nn.Module):
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Minimal implementation of BlipVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Iterable, Optional, Tuple, Union
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -415,7 +415,8 @@ class BlipVisionModel(nn.Module):
|
||||
|
||||
return self.post_layernorm(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -423,6 +424,7 @@ class BlipVisionModel(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
@ -440,8 +442,8 @@ class BlipVisionModel(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -450,3 +452,5 @@ class BlipVisionModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -692,6 +692,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only BLOOM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -341,8 +341,10 @@ class BloomForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
@ -371,3 +373,5 @@ class BloomForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -1034,7 +1034,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -1044,6 +1045,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -1111,3 +1113,5 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -3,7 +3,8 @@
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from argparse import Namespace
|
||||
from array import array
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
|
||||
from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
|
||||
TypedDict)
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@ -645,7 +646,8 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
# Merge two ColumnParallelLinear into one MergedColumnParallelLinear
|
||||
merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
|
||||
"transformer.vision.linear_proj.merged_proj.weight": {
|
||||
@ -655,6 +657,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
}
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
is_weight_to_be_merge = False
|
||||
for _, merged_weight_dict in merged_weights_dict.items():
|
||||
@ -677,6 +680,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
for combined_name, merged_weight_dict in merged_weights_dict.items():
|
||||
if combined_name in params_dict:
|
||||
@ -686,3 +690,5 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, combined_weight)
|
||||
loaded_params.add(combined_name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -483,7 +483,8 @@ class CLIPVisionModel(nn.Module):
|
||||
|
||||
# (TODO) Add prefix argument for filtering out weights to be loaded
|
||||
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -491,6 +492,7 @@ class CLIPVisionModel(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
@ -508,8 +510,9 @@ class CLIPVisionModel(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -518,3 +521,5 @@ class CLIPVisionModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -402,7 +402,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -447,3 +448,4 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -417,13 +417,15 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
|
||||
expert_params_mapping = [(
|
||||
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
|
||||
f"mlp.{weight_name}",
|
||||
) for weight_name in ["w1", "v1", "w2"]]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -447,3 +449,5 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only DeciLM model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@ -57,7 +57,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
delattr(config, "num_key_value_heads_per_layer")
|
||||
super().__init__(vllm_config=vllm_config)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -67,6 +68,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -97,6 +99,8 @@ class DeciLMForCausalLM(LlamaForCausalLM):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
|
||||
hidden_size = self.config.hidden_size
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Deepseek model."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -442,7 +442,8 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -453,6 +454,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -487,3 +489,5 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only DeepseekV2 model."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -550,7 +550,8 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
@ -566,6 +567,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -623,3 +625,5 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -22,7 +22,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Exaone model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -513,7 +513,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -523,6 +524,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
(".gate_up_proj", ".c_fc_1", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -543,6 +545,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -576,6 +579,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
"""PyTorch Falcon model."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -473,7 +473,8 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
if self.config.new_decoder_architecture:
|
||||
total_num_kv_heads = self.config.num_kv_heads
|
||||
@ -483,6 +484,7 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
total_num_kv_heads = total_num_heads
|
||||
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if name == "lm_head.weight" and self.tie_word_embeddings:
|
||||
# Falcon uses tied embeddings except Falcon-11b.
|
||||
@ -519,3 +521,5 @@ class FalconForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -156,7 +156,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -165,12 +166,13 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -183,6 +185,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Florence2ForConditionalGeneration(nn.Module):
|
||||
@ -248,10 +252,11 @@ class Florence2ForConditionalGeneration(nn.Module):
|
||||
) -> SamplerOutput:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
skip_prefixes = [
|
||||
'image_projection', "vision_tower", "image_proj_norm",
|
||||
"image_pos_embed", "visual_temporal_embed"
|
||||
]
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -16,7 +16,8 @@
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from array import array
|
||||
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -354,6 +355,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
next_tokens = self.language_model.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -424,7 +424,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -469,3 +470,4 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
logger.warning(
|
||||
"Some weights are not initialized from checkpoints: %s",
|
||||
unloaded_params)
|
||||
return loaded_params
|
||||
|
||||
@ -312,7 +312,8 @@ class Gemma2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -354,6 +355,7 @@ class Gemma2Model(nn.Module):
|
||||
logger.warning(
|
||||
"Some weights are not initialized from checkpoints: %s",
|
||||
unloaded_params)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
@ -451,13 +453,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -298,8 +298,10 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
@ -328,3 +330,5 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -323,8 +323,10 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
@ -344,3 +346,5 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader(param, loaded_weight, 'v')
|
||||
else:
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-J model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -291,7 +291,8 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -301,6 +302,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
@ -330,3 +332,5 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -303,8 +303,10 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
@ -337,3 +339,5 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM Granite model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -455,7 +455,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -465,6 +466,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -485,6 +487,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -518,6 +521,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GraniteMoe model."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -419,7 +419,8 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
new_weights = {}
|
||||
for n, p in weights:
|
||||
if n.endswith('.block_sparse_moe.input_linear.weight'):
|
||||
@ -452,4 +453,5 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
pass
|
||||
else:
|
||||
new_weights[n] = p
|
||||
mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())
|
||||
return mixtral.MixtralForCausalLM.load_weights(self,
|
||||
new_weights.items())
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Idefics2 model."""
|
||||
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -331,7 +331,8 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
return last_hidden_state
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -339,11 +340,13 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -352,3 +355,5 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
|
||||
import math
|
||||
from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
|
||||
Optional, Tuple, TypedDict, Union)
|
||||
Optional, Set, Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
@ -751,9 +751,10 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from functools import partial
|
||||
from typing import Iterable, Optional, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -469,10 +469,14 @@ class InternVisionModel(nn.Module):
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -369,13 +369,15 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w1", 0),
|
||||
("gate_up_proj", "w3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -402,3 +404,5 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
# --------------------------------------------------------
|
||||
import re
|
||||
from functools import cached_property, partial
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -663,6 +663,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
"""Inference-only Jais model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -350,8 +350,10 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
@ -382,3 +384,5 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""Inference-only Jamba model."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -462,7 +462,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -479,6 +480,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -534,6 +536,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
def _is_moe_layer(name: str):
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -350,7 +350,8 @@ class LlamaModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -360,6 +361,7 @@ class LlamaModel(nn.Module):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -375,6 +377,7 @@ class LlamaModel(nn.Module):
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -390,7 +393,6 @@ class LlamaModel(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
@ -408,6 +410,8 @@ class LlamaModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
@ -577,13 +581,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(
|
||||
return loader.load_weights(
|
||||
self.maybe_remap_mistral(name, loaded_weight)
|
||||
for name, loaded_weight in weights)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -547,6 +547,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -654,6 +654,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import math
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
@ -445,10 +445,11 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
# This model doesn't support images for now
|
||||
ignore_unexpected_prefixes=["image_newline"],
|
||||
)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import math
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
@ -887,6 +887,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""PyTorch MAMBA model."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -243,8 +243,10 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
@ -256,3 +258,5 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -156,8 +156,10 @@ class Medusa(nn.Module):
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
weights_map = {}
|
||||
|
||||
@ -181,9 +183,12 @@ class Medusa(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
if self.token_map is not None:
|
||||
self.token_map.to(device=self.lm_heads[0].weight.device)
|
||||
|
||||
assert (self.truncated_vocab_size
|
||||
== self.orig_vocab_size) or (self.token_map is not None)
|
||||
|
||||
return loaded_params
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -539,7 +539,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -556,6 +557,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -606,3 +608,5 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -24,7 +24,7 @@ import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
|
||||
Tuple, TypedDict, Union)
|
||||
Set, Tuple, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.types
|
||||
@ -602,7 +602,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -612,6 +613,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
@ -630,10 +632,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
if is_pp_missing_parameter(
|
||||
name.replace(weight_name, param_name), self):
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -646,6 +648,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -404,7 +404,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -421,6 +422,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
num_experts=self.config.num_local_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -478,3 +480,5 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -409,7 +409,8 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -418,6 +419,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -448,3 +450,5 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch Mllama model."""
|
||||
import math
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
@ -1427,7 +1427,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
return outputs
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -1437,7 +1438,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
updated_params = set()
|
||||
updated_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if 'patch_embedding.weight' in name:
|
||||
name = name.replace('patch_embedding.weight',
|
||||
@ -1457,6 +1458,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
updated_params.add(name)
|
||||
return updated_params
|
||||
|
||||
|
||||
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import Iterable, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -188,11 +188,15 @@ class MLPSpeculator(nn.Module):
|
||||
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
param = params_dict.get(name.replace("speculator.", ""))
|
||||
if param is not None:
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -324,8 +324,10 @@ class MPTForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
@ -336,3 +338,5 @@ class MPTForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Nemotron model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -474,7 +474,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -482,6 +483,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -522,3 +524,5 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -356,7 +356,8 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -366,6 +367,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -402,3 +404,5 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OLMoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -364,7 +364,8 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -383,6 +384,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -455,3 +457,5 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -16,7 +16,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only OPT model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -394,7 +394,8 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -402,6 +403,7 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "lm_head.weight" in name and self.config.tie_word_embeddings:
|
||||
continue
|
||||
@ -431,3 +433,5 @@ class OPTForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# Copyright (c) OrionStar Inc.
|
||||
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
|
||||
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -327,7 +327,8 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -337,6 +338,7 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -368,3 +370,5 @@ class OrionForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
|
||||
import torch
|
||||
@ -295,6 +295,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only persimmon model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -324,8 +324,10 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -358,3 +360,5 @@ class PersimmonForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -34,7 +34,7 @@
|
||||
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -345,7 +345,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -353,6 +354,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
("qkv_proj", "v_proj", "v")
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
@ -383,3 +385,5 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -457,9 +457,11 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
|
||||
sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -471,3 +473,5 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
import itertools
|
||||
import re
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
@ -744,7 +744,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
"model.vision_embed_tokens.wte": "embed_tokens",
|
||||
@ -759,5 +760,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
# The HF config doesn't specify whether these are tied,
|
||||
# so we detect it this way
|
||||
if "embed_tokens" not in autoloaded_weights:
|
||||
if "embed_tokens.weight" not in autoloaded_weights:
|
||||
self.embed_tokens = self.language_model.model.embed_tokens
|
||||
autoloaded_weights.add("embed_tokens.weight")
|
||||
return autoloaded_weights
|
||||
|
||||
@ -20,7 +20,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only PhiMoE model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -598,7 +598,8 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -613,6 +614,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
num_experts=self.config.num_local_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -666,3 +668,5 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from itertools import tee
|
||||
from typing import Iterable, List, Mapping, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@ -1053,7 +1053,8 @@ class PixtralHFVisionModel(nn.Module):
|
||||
|
||||
# (TODO) Add prefix argument for filtering out weights to be loaded
|
||||
# ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -1063,6 +1064,7 @@ class PixtralHFVisionModel(nn.Module):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.transformer.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
@ -1075,8 +1077,8 @@ class PixtralHFVisionModel(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -1085,3 +1087,5 @@ class PixtralHFVisionModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -8,7 +8,7 @@ import math
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Tuple, TypedDict, Union)
|
||||
Optional, Set, Tuple, TypedDict, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -964,13 +964,15 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w2", 0),
|
||||
("gate_up_proj", "w1", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -999,6 +1001,8 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class QWenLLM(QWenBaseModel):
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -332,7 +332,8 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -342,6 +343,7 @@ class Qwen2Model(nn.Module):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -372,6 +374,8 @@ class Qwen2Model(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
@ -494,13 +498,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
@ -564,7 +569,8 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["lm_head."])
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -20,7 +20,8 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
|
||||
from functools import lru_cache
|
||||
from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
|
||||
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
|
||||
Union)
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
@ -420,7 +421,8 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -430,6 +432,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -463,3 +466,5 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
"""Inference-only Qwen2-Classification model compatible with HF weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -97,7 +97,8 @@ class Qwen2ForSequenceClassification(nn.Module, SupportsLoRA, SupportsPP):
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["lm_head."])
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -436,7 +436,8 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -455,6 +456,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
|
||||
num_experts=self.config.num_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -532,3 +534,5 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -110,7 +110,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["lm_head."])
|
||||
loader.load_weights(weights)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Tuple, Type, TypedDict, Union)
|
||||
Optional, Set, Tuple, Type, TypedDict, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -1333,7 +1333,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -1343,6 +1344,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -1392,3 +1394,5 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
within a vision language model."""
|
||||
|
||||
import math
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -594,7 +594,8 @@ class SiglipVisionModel(nn.Module):
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -602,6 +603,7 @@ class SiglipVisionModel(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
] if self.shard_weight else []
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
@ -619,8 +621,9 @@ class SiglipVisionModel(nn.Module):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
@ -629,3 +632,5 @@ class SiglipVisionModel(nn.Module):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -21,7 +21,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Solar model compatible with HuggingFace weights."""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -477,7 +477,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@ -487,6 +488,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -502,6 +504,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
@ -535,6 +538,8 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
|
||||
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
|
||||
model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -306,7 +306,8 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -316,6 +317,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -347,3 +349,5 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Starcoder2 model."""
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
from typing import Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -314,7 +314,8 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
@ -323,6 +324,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
@ -346,3 +348,5 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
import math
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union, cast)
|
||||
|
||||
import numpy as np
|
||||
@ -504,10 +504,11 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
|
||||
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["audio_tower."])
|
||||
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||
return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Protocol, Tuple, Union, overload)
|
||||
Optional, Protocol, Set, Tuple, Union, overload)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -172,8 +172,9 @@ class AutoWeightsLoader:
|
||||
if module != self.module:
|
||||
module_load_weights = getattr(module, "load_weights", None)
|
||||
if callable(module_load_weights):
|
||||
module_load_weights(weights)
|
||||
return
|
||||
loaded_params = module_load_weights(weights)
|
||||
yield from map(lambda x: self._get_qualname(base_prefix, x),
|
||||
loaded_params)
|
||||
|
||||
child_modules = dict(module.named_children())
|
||||
child_params = dict(module.named_parameters(recurse=False))
|
||||
@ -222,11 +223,11 @@ class AutoWeightsLoader:
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
*,
|
||||
mapper: Optional[WeightsMapper] = None,
|
||||
) -> List[str]:
|
||||
) -> Set[str]:
|
||||
if mapper is not None:
|
||||
weights = mapper.apply(weights)
|
||||
|
||||
autoloaded_weights = list(self._load_module("", self.module, weights))
|
||||
autoloaded_weights = set(self._load_module("", self.module, weights))
|
||||
return autoloaded_weights
|
||||
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Xverse model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -376,7 +376,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
@ -385,6 +386,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if ("rotary_emb.inv_freq" in name
|
||||
or "rotary_emb.cos_cached" in name
|
||||
@ -413,3 +415,5 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user