[Bugfix] Fix Plamo3 rope handling (#29092)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit 0e741c12e3
parent 56669c1f29
Author: Cyrus Leung
Date: 2025-11-21 11:38:35 +08:00 (committed by GitHub)

@@ -62,7 +62,7 @@ class Plamo3Config(PretrainedConfig):  # type: ignore
     # if `sliding_window` is list
     interleaved_sliding_window: list[int | None]
     sliding_window_pattern: int
-    rope_theta: int
+    rope_parameters: dict[str, Any]
     rope_local_theta: int
     # MLP
     intermediate_size: int
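
For context, the new `rope_parameters` field replaces the scalar `rope_theta` and can hold either of two shapes, depending on the Transformers version that wrote the config. A minimal sketch of both, assuming the layer-type keys used elsewhere in this diff (`"sliding_attention"` appears in the code below; `"full_attention"` and the theta values are illustrative placeholders, not Plamo3's real defaults):

```python
from typing import Any

# Transformers v4 style: one flat dict shared by every layer.
rope_parameters_v4: dict[str, Any] = {
    "rope_type": "default",
    "rope_theta": 1_000_000,  # placeholder, not Plamo3's actual theta
}

# Transformers v5 style: keyed by layer type, so global and sliding-window
# attention layers can carry different rope settings.
rope_parameters_v5: dict[str, Any] = {
    "full_attention": {"rope_type": "default", "rope_theta": 1_000_000},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000},
}
```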
@@ -153,13 +153,24 @@ class Plamo3AttentionMixer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
-        layer_idx = extract_layer_index(prefix)
-        full_attn = config.interleaved_sliding_window[layer_idx] is None
-        self.rope_theta = config.rope_theta if full_attn else config.rope_local_theta
-        self.rope_scaling = (
-            config.rope_scaling if hasattr(config, "rope_scaling") else None
-        )
+        layer_idx = extract_layer_index(prefix)
+        layer_type = config.layer_types[layer_idx]
+        is_sliding = layer_type == "sliding_attention"
+
+        # Initialize the rotary embedding.
+        if layer_type in config.rope_parameters:
+            # Transformers v5 rope config.
+            rope_parameters = config.rope_parameters[layer_type]
+        else:
+            # Transformers v4 rope config.
+            # Global attention. Use the values in config.json.
+            rope_parameters = config.rope_parameters
+            # Local attention. Override the values in config.json.
+            if is_sliding:
+                rope_parameters = dict(
+                    rope_type="default", rope_theta=config.rope_local_theta
+                )
 
         max_position = config.max_position_embeddings
         if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
             vllm_config.model_config.max_model_len, int
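
The per-layer selection above reads naturally as a pure function. Here is a hypothetical standalone restatement (`resolve_rope_parameters` is not a helper in the tree, and `"full_attention"` is an assumed layer-type name) with a quick check of both config styles:

```python
from typing import Any


def resolve_rope_parameters(
    rope_parameters: dict[str, Any],
    layer_type: str,
    rope_local_theta: int,
) -> dict[str, Any]:
    """Mirror of the per-layer selection in the diff (hypothetical helper)."""
    if layer_type in rope_parameters:
        # Transformers v5: parameters are already keyed by layer type.
        return rope_parameters[layer_type]
    if layer_type == "sliding_attention":
        # Transformers v4: local-attention layers take the separate local theta.
        return {"rope_type": "default", "rope_theta": rope_local_theta}
    # Transformers v4: global-attention layers use the flat config values.
    return rope_parameters


# v4-style flat config: global layers keep it, sliding layers get local theta.
flat = {"rope_type": "default", "rope_theta": 1_000_000}
assert resolve_rope_parameters(flat, "full_attention", 10_000) is flat
assert resolve_rope_parameters(flat, "sliding_attention", 10_000)["rope_theta"] == 10_000

# v5-style nested config: each layer type resolves to its own entry.
nested = {"full_attention": flat, "sliding_attention": {"rope_theta": 10_000}}
assert resolve_rope_parameters(nested, "sliding_attention", 0)["rope_theta"] == 10_000
```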
@@ -170,8 +181,7 @@
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
-            base=self.rope_theta,
-            rope_scaling=self.rope_scaling,
+            rope_parameters=rope_parameters,
         )
         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         set_weight_attrs(
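
Downstream, the call site now hands the resolved dict straight through. A minimal sketch of the updated call, assuming `get_rope` is still importable from `vllm.model_executor.layers.rotary_embedding` and using illustrative sizes:

```python
# Sketch of the updated call site, mirroring the diff above. The import path is
# an assumption about the current vLLM tree; the sizes are placeholders.
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim = 128       # illustrative, not Plamo3's actual head size
max_position = 4096  # illustrative context length

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position,
    rope_parameters={"rope_type": "default", "rope_theta": 10_000},
)
```

Collapsing the separate `base=` and `rope_scaling=` arguments into one `rope_parameters` dict keeps the theta and any scaling settings together, which is what lets the per-layer-type selection earlier in this diff pass the whole configuration through unchanged.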