mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 17:21:21 +08:00
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
parent
359d293006
commit
77d906995c
@ -26,6 +26,7 @@ from vllm.attention import Attention
|
|||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
from vllm.config import CacheConfig, VllmConfig
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
|
||||||
GeluAndMul,
|
GeluAndMul,
|
||||||
@ -44,6 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, maybe_remap_kv_scale_name)
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
|
||||||
|
|
||||||
from .interfaces import SupportsQuant
|
from .interfaces import SupportsQuant
|
||||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||||
@ -51,6 +53,8 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
EPS = torch.tensor(torch.finfo().min)
|
||||||
|
|
||||||
|
|
||||||
class Gemma3nAltUp(nn.Module):
|
class Gemma3nAltUp(nn.Module):
|
||||||
"""Alternating updates (Altup)
|
"""Alternating updates (Altup)
|
||||||
@ -532,16 +536,29 @@ class Gemma3nDecoderLayer(nn.Module):
|
|||||||
return corrected_predictions
|
return corrected_predictions
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile
|
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||||
|
kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nSelfDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Includes altup embedding and self decoder layers
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
decoder_layers: list[Gemma3nDecoderLayer],
|
||||||
|
layer_idx_start: int,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.layer_idx_start = layer_idx_start
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
|
||||||
quant_config = vllm_config.quant_config
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
config.vocab_size,
|
config.vocab_size,
|
||||||
@ -594,32 +611,6 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
prefix=f"{prefix}.altup_projections.{idx-1}",
|
prefix=f"{prefix}.altup_projections.{idx-1}",
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
])
|
])
|
||||||
self.altup_unembed_projections = nn.ModuleList([
|
|
||||||
ColumnParallelLinear(
|
|
||||||
config.hidden_size,
|
|
||||||
config.hidden_size,
|
|
||||||
bias=False,
|
|
||||||
gather_output=True,
|
|
||||||
return_bias=False,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
|
||||||
) for idx in range(1, self.config.altup_num_inputs)
|
|
||||||
])
|
|
||||||
|
|
||||||
# Transformer blocks.
|
|
||||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
||||||
config.num_hidden_layers,
|
|
||||||
lambda prefix: Gemma3nDecoderLayer(
|
|
||||||
config, cache_config, quant_config, prefix=prefix),
|
|
||||||
prefix=f"{prefix}.layers")
|
|
||||||
self.norm = RMSNorm(
|
|
||||||
config.hidden_size,
|
|
||||||
eps=config.rms_norm_eps,
|
|
||||||
)
|
|
||||||
self.eps = torch.tensor(torch.finfo().min)
|
|
||||||
|
|
||||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.embed_tokens(input_ids) * self.embed_scale
|
|
||||||
|
|
||||||
def get_per_layer_input_embeddings(
|
def get_per_layer_input_embeddings(
|
||||||
self, input_ids: torch.Tensor) -> torch.Tensor:
|
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
@ -633,20 +624,11 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
return self.embed_tokens_per_layer(
|
return self.embed_tokens_per_layer(
|
||||||
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
per_layer_inputs_tokens) * self.embed_scale_per_layer
|
||||||
|
|
||||||
def forward(
|
def get_per_layer_inputs(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.Tensor],
|
hidden_states_0: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
per_layer_inputs: Optional[torch.Tensor],
|
||||||
per_layer_inputs: torch.Tensor,
|
) -> torch.Tensor:
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
if inputs_embeds is not None:
|
|
||||||
hidden_states_0 = inputs_embeds
|
|
||||||
else:
|
|
||||||
hidden_states_0 = self.get_input_embeddings(input_ids)
|
|
||||||
|
|
||||||
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
per_layer_projection = self.per_layer_model_projection(hidden_states_0)
|
||||||
per_layer_projection = per_layer_projection.reshape(
|
per_layer_projection = per_layer_projection.reshape(
|
||||||
*hidden_states_0.shape[:-1],
|
*hidden_states_0.shape[:-1],
|
||||||
@ -655,14 +637,18 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
)
|
)
|
||||||
per_layer_projection = self.per_layer_projection_norm(
|
per_layer_projection = self.per_layer_projection_norm(
|
||||||
per_layer_projection)
|
per_layer_projection)
|
||||||
|
|
||||||
if per_layer_inputs is not None:
|
if per_layer_inputs is not None:
|
||||||
# Profiling run does not compute per_layer_inputs
|
# Profiling run does not compute per_layer_inputs
|
||||||
per_layer_inputs = per_layer_projection + per_layer_inputs
|
per_layer_inputs = per_layer_projection + per_layer_inputs
|
||||||
per_layer_inputs *= self.per_layer_input_scale
|
per_layer_inputs *= self.per_layer_input_scale
|
||||||
else:
|
else:
|
||||||
per_layer_inputs = per_layer_projection
|
per_layer_inputs = per_layer_projection
|
||||||
|
return per_layer_inputs
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.embed_tokens(input_ids) * self.embed_scale
|
||||||
|
|
||||||
|
def altup_embed(self, hidden_states_0: torch.Tensor) -> torch.Tensor:
|
||||||
# Altup embed.
|
# Altup embed.
|
||||||
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
hidden_states = [hidden_states_0] * self.config.altup_num_inputs
|
||||||
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
target_magnitude = torch.mean(hidden_states_0**2, dim=-1,
|
||||||
@ -673,11 +659,77 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
hidden_states[i] *= target_magnitude / torch.maximum(
|
||||||
new_magnitude, self.eps)
|
new_magnitude, EPS)
|
||||||
hidden_states = torch.stack(hidden_states, dim=0)
|
hidden_states = torch.stack(hidden_states, dim=-1)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
# Transformer blocks.
|
def forward(
|
||||||
for layer_idx, layer in enumerate(self.layers):
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
hidden_states_0 = inputs_embeds
|
||||||
|
else:
|
||||||
|
hidden_states_0 = self.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
adjusted_per_layer_inputs = self.get_per_layer_inputs(
|
||||||
|
hidden_states_0, per_layer_inputs)
|
||||||
|
hidden_states = self.altup_embed(hidden_states_0)
|
||||||
|
|
||||||
|
# [altnum_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = hidden_states.permute(2, 0, 1)
|
||||||
|
|
||||||
|
for idx, layer in enumerate(self.decoder_layers):
|
||||||
|
layer_idx = idx + self.layer_idx_start
|
||||||
|
# [altup_num_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = layer(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
per_layer_input=adjusted_per_layer_inputs[:, layer_idx, :],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [num_tokens, hidden_size, altnum_inputs]
|
||||||
|
hidden_states = hidden_states.permute(1, 2, 0)
|
||||||
|
|
||||||
|
return hidden_states, adjusted_per_layer_inputs
|
||||||
|
|
||||||
|
|
||||||
|
# This enables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
|
@support_torch_compile(enable_if=lambda vllm_config: vllm_config.cache_config.
|
||||||
|
kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nCrossDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Cross-decoder layers
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
prefix: str = "",
|
||||||
|
decoder_layers: list[Gemma3nDecoderLayer],
|
||||||
|
layer_idx_start: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.decoder_layers = decoder_layers
|
||||||
|
self.layer_idx_start = layer_idx_start
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
per_layer_inputs: torch.Tensor,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# [altnum_inputs, num_tokens, hidden_size]
|
||||||
|
hidden_states = hidden_states.permute(2, 0, 1)
|
||||||
|
for idx, layer in enumerate(self.decoder_layers):
|
||||||
|
layer_idx = idx + self.layer_idx_start
|
||||||
# [altup_num_inputs, num_tokens, hidden_size]
|
# [altup_num_inputs, num_tokens, hidden_size]
|
||||||
hidden_states = layer(
|
hidden_states = layer(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@ -685,22 +737,249 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
per_layer_input=per_layer_inputs[:, layer_idx, :],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
# [num_tokens, hidden_size, altnum_inputs]
|
||||||
|
hidden_states = hidden_states.permute(1, 2, 0)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# This disables torch.compile if --kv-sharing-fast-prefill passed
|
||||||
|
@support_torch_compile(enable_if=lambda vllm_config: not vllm_config.
|
||||||
|
cache_config.kv_sharing_fast_prefill)
|
||||||
|
class Gemma3nTextModel(nn.Module, SupportsQuant):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
self.config = config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
self.altup_unembed_projections = nn.ModuleList([
|
||||||
|
ColumnParallelLinear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.hidden_size,
|
||||||
|
bias=False,
|
||||||
|
gather_output=True,
|
||||||
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.altup_unembed_projections.{idx-1}",
|
||||||
|
) for idx in range(1, self.config.altup_num_inputs)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Allocate config.num_kv_shared_layers layers for self-decoder
|
||||||
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||||
|
config.num_hidden_layers,
|
||||||
|
lambda prefix: Gemma3nDecoderLayer(
|
||||||
|
config, cache_config, quant_config, prefix=prefix),
|
||||||
|
prefix=f"{prefix}.layers")
|
||||||
|
|
||||||
|
first_kv_shared_layer_idx = (config.num_hidden_layers -
|
||||||
|
config.num_kv_shared_layers)
|
||||||
|
|
||||||
|
# NOTE(sarckk): importing this top level seems to cause issues
|
||||||
|
# during running of tests.
|
||||||
|
from vllm.compilation.backends import set_model_tag
|
||||||
|
|
||||||
|
# Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
|
||||||
|
with set_model_tag("self_decoder"):
|
||||||
|
self.self_decoder = Gemma3nSelfDecoder(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=f"{prefix}.self_decoder",
|
||||||
|
decoder_layers=self.layers[:first_kv_shared_layer_idx],
|
||||||
|
layer_idx_start=0,
|
||||||
|
)
|
||||||
|
# Layer idx 20-30 are cross-decoder layers in YOCO
|
||||||
|
with set_model_tag("cross_decoder"):
|
||||||
|
self.cross_decoder = Gemma3nCrossDecoder(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=f"{prefix}.cross_decoder",
|
||||||
|
decoder_layers=self.layers[first_kv_shared_layer_idx:],
|
||||||
|
layer_idx_start=first_kv_shared_layer_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(
|
||||||
|
config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fast_prefill_enabled = cache_config.kv_sharing_fast_prefill
|
||||||
|
|
||||||
|
if self.fast_prefill_enabled:
|
||||||
|
# Allocate static buffers for CUDAGraph
|
||||||
|
# TODO(sarckk): Extract this functionality to interface
|
||||||
|
max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
|
device = next(self.parameters()).device
|
||||||
|
self.positions = torch.zeros(max_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
|
self.hidden_states = torch.zeros(
|
||||||
|
(max_num_tokens, config.hidden_size,
|
||||||
|
self.config.altup_num_inputs),
|
||||||
|
dtype=self.embed_tokens.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
self.per_layer_inputs = torch.zeros(
|
||||||
|
(max_num_tokens, self.config.num_hidden_layers,
|
||||||
|
self.config.hidden_size_per_layer_input),
|
||||||
|
dtype=self.embed_tokens.weight.dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embed_tokens(self):
|
||||||
|
return self.self_decoder.embed_tokens
|
||||||
|
|
||||||
|
def get_per_layer_input_embeddings(
|
||||||
|
self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.self_decoder.get_per_layer_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.self_decoder.get_input_embeddings(input_ids)
|
||||||
|
|
||||||
|
def fast_prefill_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
logits_indices_padded, num_logits_indices = None, None
|
||||||
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
|
|
||||||
|
# attn_metadata is None during dummy runs
|
||||||
|
if (self.fast_prefill_enabled and attn_metadata is not None):
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
|
# Last layer is a KV sharing layer
|
||||||
|
layer_attn_metadata = attn_metadata[
|
||||||
|
self.layers[-1].self_attn.attn.layer_name]
|
||||||
|
if (isinstance(layer_attn_metadata, KVSharingFastPrefillMetadata)):
|
||||||
|
logits_indices_padded = (
|
||||||
|
layer_attn_metadata.logits_indices_padded)
|
||||||
|
num_logits_indices = layer_attn_metadata.num_logits_indices
|
||||||
|
|
||||||
|
# Copy inputs for cudagraph
|
||||||
|
batch_size = positions.size(0)
|
||||||
|
self.positions[:batch_size].copy_(positions)
|
||||||
|
self_decoder_hidden_states, per_layer_inputs_adjusted = \
|
||||||
|
self.self_decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=self.positions[:batch_size],
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if logits_indices_padded is None:
|
||||||
|
logits_indices_padded = torch.arange(
|
||||||
|
positions.size(0),
|
||||||
|
dtype=positions.dtype,
|
||||||
|
device=positions.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(sarckk): There is currently a bug caused by
|
||||||
|
# vLLM converting output of last piecewise CUDA graph
|
||||||
|
# to weakref, causing memory to be prematurely freed
|
||||||
|
# when there are multiple compilation units
|
||||||
|
# Keep .clone() until fix in
|
||||||
|
# https://github.com/vllm-project/vllm/pull/22282
|
||||||
|
hidden_states = self_decoder_hidden_states.clone()
|
||||||
|
|
||||||
|
# Copy inputs for cudagraph
|
||||||
|
num_padded_logits_indices = logits_indices_padded.size(0)
|
||||||
|
self.positions[:num_padded_logits_indices].copy_(
|
||||||
|
positions[logits_indices_padded])
|
||||||
|
self.hidden_states[:num_padded_logits_indices].copy_(
|
||||||
|
self_decoder_hidden_states[logits_indices_padded])
|
||||||
|
self.per_layer_inputs[:num_padded_logits_indices].copy_(
|
||||||
|
per_layer_inputs_adjusted[logits_indices_padded])
|
||||||
|
cross_decoder_hidden_states = self.cross_decoder(
|
||||||
|
positions=self.positions[:num_padded_logits_indices],
|
||||||
|
hidden_states=self.hidden_states[:num_padded_logits_indices],
|
||||||
|
per_layer_inputs=self.per_layer_inputs[:num_padded_logits_indices],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_logits_indices is not None:
|
||||||
|
assert num_logits_indices > 0
|
||||||
|
# Merge cross-decoder and self-decoder hidden states
|
||||||
|
hidden_states[logits_indices_padded[:num_logits_indices]] = (
|
||||||
|
cross_decoder_hidden_states[:num_logits_indices])
|
||||||
|
else:
|
||||||
|
hidden_states = cross_decoder_hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def normal_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
hidden_states, per_layer_inputs = self.self_decoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = self.cross_decoder(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
per_layer_inputs=per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def altup_unembed(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
# Altup unembed.
|
# Altup unembed.
|
||||||
target_magnitude = torch.mean(hidden_states[0]**2,
|
target_magnitude = torch.mean(hidden_states[..., 0]**2,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
for i in range(1, self.config.altup_num_inputs):
|
for i in range(1, self.config.altup_num_inputs):
|
||||||
hidden_states[i] = self.altup_unembed_projections[i - 1](
|
hidden_states[..., i] = self.altup_unembed_projections[i - 1](
|
||||||
hidden_states[i])
|
hidden_states[..., i])
|
||||||
new_magnitude = torch.mean(hidden_states[i]**2,
|
new_magnitude = torch.mean(hidden_states[..., i]**2,
|
||||||
dim=-1,
|
dim=-1,
|
||||||
keepdim=True)**0.5
|
keepdim=True)**0.5
|
||||||
hidden_states[i] *= target_magnitude / torch.maximum(
|
hidden_states[..., i] *= target_magnitude / torch.maximum(
|
||||||
new_magnitude, self.eps)
|
new_magnitude, EPS)
|
||||||
# [altup_num_inputs,num_tokens,hidden_size] -> [num_tokens,hidden_size]
|
# [num_tokens,hidden_size, altup_num_inputs] -> [num_tokens,hidden_size]
|
||||||
hidden_states = torch.mean(hidden_states, dim=0)
|
hidden_states = torch.mean(hidden_states, dim=-1)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.Tensor],
|
||||||
|
positions: torch.Tensor,
|
||||||
|
per_layer_inputs: Optional[torch.Tensor] = None,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if self.fast_prefill_enabled:
|
||||||
|
hidden_states = self.fast_prefill_forward(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
inputs_embeds,
|
||||||
|
per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = self.normal_forward(
|
||||||
|
input_ids,
|
||||||
|
positions,
|
||||||
|
inputs_embeds,
|
||||||
|
per_layer_inputs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = self.altup_unembed(hidden_states)
|
||||||
return self.norm(hidden_states)
|
return self.norm(hidden_states)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
@ -716,6 +995,13 @@ class Gemma3nTextModel(nn.Module, SupportsQuant):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: set[str] = set()
|
loaded_params: set[str] = set()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
# decoder layer weights, altup_unembed_projections and rmsnorm
|
||||||
|
# are initialized in text model, others are in self decoder
|
||||||
|
if (not name.startswith('layers')
|
||||||
|
and not name.startswith('altup_unembed_projections')
|
||||||
|
and not name.startswith('norm')):
|
||||||
|
name = f"self_decoder.{name}"
|
||||||
|
|
||||||
if (self.quant_config is not None and
|
if (self.quant_config is not None and
|
||||||
(scale_name := self.quant_config.get_cache_scale(name))):
|
(scale_name := self.quant_config.get_cache_scale(name))):
|
||||||
# Loading kv cache scales for compressed-tensors quantization
|
# Loading kv cache scales for compressed-tensors quantization
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user