[Model] Update support for NemotronNAS models (#15008)

Signed-off-by: Nave Assaf <nassaf@nvidia.com>
Naveassaf 2025-03-31 15:35:14 +03:00 committed by GitHub
parent 555aa21905
commit 3aa2b6a637
8 changed files with 524 additions and 133 deletions

docs/source/models/supported_models.md

@@ -224,7 +224,7 @@ See [this page](#generative-models) for more information on how to use generative models.
* ✅︎
- * `DeciLMForCausalLM`
* DeciLM
* `Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.
* `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc.
*
* ✅︎
- * `DeepseekForCausalLM`

tests/models/registry.py

@@ -112,7 +112,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
trust_remote_code=True),
"DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
"DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
"DeciLMForCausalLM": _HfExamplesInfo("nvidia/Llama-3_3-Nemotron-Super-49B-v1", # noqa: E501
trust_remote_code=True),
"DeepseekForCausalLM": _HfExamplesInfo("deepseek-ai/deepseek-llm-7b-chat"),
"DeepseekV2ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V2-Lite-Chat", # noqa: E501

vllm/config.py

@@ -411,6 +411,7 @@ class ModelConfig:
self.is_attention_free = self._init_attention_free()
self.is_hybrid = self._init_is_hybrid()
self.has_noops = self._init_has_noops()
self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron():
@@ -510,6 +511,10 @@ class ModelConfig:
def _init_is_hybrid(self) -> bool:
return self.registry.is_hybrid_model(self.architectures)
def _init_has_noops(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return self.registry.is_noops_model(architectures)
def _init_has_inner_state(self) -> bool:
return self.registry.model_has_inner_state(self.architectures)
@@ -872,6 +877,14 @@ class ModelConfig:
return getattr(self.hf_config.attn_config, "kv_n_heads",
self.hf_config.num_attention_heads)
if self.hf_config.model_type == "nemotron-nas":
for block in self.hf_config.block_configs:
if not block.attention.no_op:
return self.hf_config.num_attention_heads \
// block.attention.n_heads_in_group
raise RuntimeError("Couldn't determine number of kv heads")
if self.is_attention_free:
return 0
@@ -940,7 +953,9 @@ class ModelConfig:
# This function relies on 'layers_block_type' in hf_config;
# without this attribute, we need workarounds like the ones below.
attn_block_type = block_type == LayerBlockType.attention
is_transformer = not self.is_hybrid and not self.is_attention_free
is_transformer = not self.is_hybrid and \
not self.has_noops and \
not self.is_attention_free
start, end = self.get_layers_start_end_indices(parallel_config)
if is_transformer:
@@ -951,6 +966,10 @@ class ModelConfig:
# Note that this code assumes there
# is only one type of attention-free block type.
return 0 if attn_block_type else end - start
elif self.has_noops:
block_configs = self.hf_config.block_configs
return sum(not bc.attention.no_op
for bc in block_configs[start:end])
else:
# Hybrid model
layers_block_type_value = getattr(self.hf_config,

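For intuition, the two new code paths above (KV-head derivation and attention-layer counting for "nemotron-nas") can be exercised with a standalone toy sketch; the SimpleNamespace objects below merely stand in for the HF block configs and are not vLLM code:

from types import SimpleNamespace

num_attention_heads = 32
block_configs = [
    SimpleNamespace(attention=SimpleNamespace(no_op=False, n_heads_in_group=8)),
    SimpleNamespace(attention=SimpleNamespace(no_op=True, n_heads_in_group=None)),
    SimpleNamespace(attention=SimpleNamespace(no_op=False, n_heads_in_group=8)),
]

# KV heads: derived from the first block whose attention is not a no-op.
for block in block_configs:
    if not block.attention.no_op:
        num_kv_heads = num_attention_heads // block.attention.n_heads_in_group
        break
print(num_kv_heads)  # 4

# Attention layer count: no-op attention blocks allocate no KV cache.
print(sum(not bc.attention.no_op for bc in block_configs))  # 2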
vllm/model_executor/models/decilm.py (deleted)

@@ -1,124 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Set, Tuple
import torch
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
from .utils import is_pp_missing_parameter
class DeciLMForCausalLM(LlamaForCausalLM):
"""
Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
Based on the llama executor.
The main difference is that DeciLM uses Variable Grouped Query Attention.
The constant number of GQA heads in the decoder is overridden with a value
per layer.
Usually, in the HuggingFace implementation, instead of
"config.num_key_value_heads", we use
"config.num_key_value_heads_per_layer[i]" which varies.
Currently, PagedAttention does not work well with variable GQA, so we
normalize the weights upon loading, and use uniform GQA with the max value
instead.
"""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(vllm_config=vllm_config)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "k_proj" in name or "v_proj" in name:
loaded_weight = self._degroup_weight(loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
hidden_size = self.config.hidden_size
head_size = self.config.hidden_size // self.config.num_attention_heads
target_num_kv_heads = self.config.num_key_value_heads
num_kv_heads = loaded_weight.shape[0] // head_size
n_repeats = target_num_kv_heads / num_kv_heads
assert n_repeats == int(n_repeats)
n_repeats = int(n_repeats)
loaded_weight = loaded_weight.view(num_kv_heads, head_size,
hidden_size)
loaded_weight = torch.repeat_interleave(loaded_weight,
repeats=n_repeats,
dim=0)
loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size,
hidden_size)
return loaded_weight

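For reference, the variable-GQA normalization this deleted loader performed is easy to reproduce in isolation; the sketch below uses toy sizes (not tied to any real checkpoint) to show how torch.repeat_interleave expands a K/V projection from a layer with fewer KV heads up to the per-model maximum:

import torch

hidden_size, head_size = 16, 4
num_kv_heads, target_num_kv_heads = 2, 4   # this layer has 2 KV heads, the model max is 4
w = torch.randn(num_kv_heads * head_size, hidden_size)

n_repeats = target_num_kv_heads // num_kv_heads
w = w.view(num_kv_heads, head_size, hidden_size)
w = torch.repeat_interleave(w, repeats=n_repeats, dim=0)  # duplicate each KV head
w = w.reshape(target_num_kv_heads * head_size, hidden_size)
print(w.shape)  # torch.Size([16, 16])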
vllm/model_executor/models/interfaces.py

@@ -411,6 +411,35 @@ def is_hybrid(
return isinstance(model, IsHybrid)
@runtime_checkable
class HasNoOps(Protocol):
has_noops: ClassVar[Literal[True]] = True
@runtime_checkable
class _HasNoOpsType(Protocol):
has_noops: ClassVar[Literal[True]]
@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
...
@overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]:
...
def has_noops(
model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]:
if isinstance(model, type):
return isinstance(model, _HasNoOpsType)
return isinstance(model, HasNoOps)
@runtime_checkable
class SupportsCrossEncoding(Protocol):
"""The interface required for all models that support cross encoding."""

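A minimal, self-contained sketch of how the new runtime-checkable protocol behaves (the toy model classes below are hypothetical): a model simply declares the class attribute, and an isinstance check against the protocol detects it without any inheritance:

from typing import ClassVar, Literal, Protocol, runtime_checkable

@runtime_checkable
class HasNoOps(Protocol):
    has_noops: ClassVar[Literal[True]] = True

class ToyNoOpModel:
    has_noops: ClassVar[Literal[True]] = True

class ToyDenseModel:
    pass

print(isinstance(ToyNoOpModel(), HasNoOps))   # True
print(isinstance(ToyDenseModel(), HasNoOps))  # False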
vllm/model_executor/models/nemotron_nas.py (new file)

@@ -0,0 +1,454 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Type, Union
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import LlamaAttention, LlamaMLP
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasNoOps, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
# DeciLM-specific code
intermediate_size = int(2 * ffn_mult * n_embd / 3)
return _find_multiple(intermediate_size, 256)
def _find_multiple(n: int, k: int) -> int:
# DeciLM-specific code
if n % k == 0:
return n
return n + k - (n % k)
class DeciLMDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
block_config = config.block_configs[layer_idx]
self._is_no_op_attention = block_config.attention.no_op
self._is_no_op_ffn = block_config.ffn.no_op
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, "qkv_bias"):
attention_bias = config.qkv_bias
if not self._is_no_op_attention:
num_kv_heads = (config.num_attention_heads //
block_config.attention.n_heads_in_group)
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=num_kv_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
if not self._is_no_op_ffn:
ffn_mult = block_config.ffn.ffn_mult
intermediate_size = _ffn_mult_to_intermediate_size(
ffn_mult, config.hidden_size)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if self._is_no_op_attention:
pass
else:
if (residual is None):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
if not self._is_no_op_ffn:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class DeciModel(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
return layer_type(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
get_layer,
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
kv_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if not layer._is_no_op_attention:
hidden_states, residual = layer(positions, hidden_states,
residual)
kv_cache_index += 1
else:
hidden_states, residual = layer(positions, hidden_states,
residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return DeciModel(vllm_config=vllm_config, prefix=prefix)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

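With the architecture registered under nemotron_nas (see the registry change below), the model runs through the usual entry points. A hedged usage sketch: the checkpoint name comes from the docs change above, while tensor_parallel_size and the sampling settings are illustrative, and the 49B model needs several GPUs:

from vllm import LLM, SamplingParams

llm = LLM(model="nvidia/Llama-3_3-Nemotron-Super-49B-v1",
          trust_remote_code=True,        # config/tokenizer ship custom code
          tensor_parallel_size=4)        # illustrative; size to your GPUs
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)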
vllm/model_executor/models/registry.py

@@ -21,9 +21,10 @@ import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_in_doc_build
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
supports_pp, supports_transcription, supports_v0_only)
from .interfaces import (has_inner_state, has_noops, is_attention_free,
is_hybrid, supports_cross_encoding,
supports_multimodal, supports_pp,
supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@@ -44,7 +45,7 @@ _TEXT_GENERATION_MODELS = {
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
@@ -118,7 +119,7 @@ _EMBEDDING_MODELS = {
"RobertaModel": ("roberta", "RobertaEmbeddingModel"),
"RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"),
"XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GritLM": ("gritlm", "GritLM"),
@@ -235,6 +236,7 @@ class _ModelInfo:
has_inner_state: bool
is_attention_free: bool
is_hybrid: bool
has_noops: bool
supports_transcription: bool
supports_v0_only: bool
@@ -252,6 +254,7 @@ class _ModelInfo:
is_hybrid=is_hybrid(model),
supports_transcription=supports_transcription(model),
supports_v0_only=supports_v0_only(model),
has_noops=has_noops(model),
)
@@ -511,6 +514,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_hybrid
def is_noops_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.has_noops
def is_transcription_model(
self,
architectures: Union[str, List[str]],

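A quick end-to-end sanity check of the new registry plumbing (a sketch; the lookup imports and inspects the model class, so it assumes a working vLLM installation):

from vllm import ModelRegistry

print(ModelRegistry.is_noops_model(["DeciLMForCausalLM"]))  # expected: True
print(ModelRegistry.is_noops_model(["LlamaForCausalLM"]))   # expected: False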
vllm/model_executor/models/utils.py

@@ -497,7 +497,10 @@ def set_cpu_offload_max_bytes(max_bytes: int) -> None:
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
device = next(module.parameters()).device
if (params := next(module.parameters(), None)) is None:
return module
device = params.device
if device == torch.device("cpu"):
return module
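The new guard matters because a fully no-op layer can have no parameters at all, in which case the previous next(module.parameters()) call would raise StopIteration. A minimal illustration:

import torch

empty = torch.nn.Module()              # stands in for a layer with no parameters
print(next(empty.parameters(), None))  # None -> maybe_offload_to_cpu returns early
# next(empty.parameters()) without a default would raise StopIteration here.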