diff --git a/vllm/model_executor/models/plamo3.py b/vllm/model_executor/models/plamo3.py
index 5bb07722a5fc1..4aeb9d432dcc6 100644
--- a/vllm/model_executor/models/plamo3.py
+++ b/vllm/model_executor/models/plamo3.py
@@ -62,7 +62,7 @@ class Plamo3Config(PretrainedConfig):  # type: ignore
     # if `sliding_window` is list
     interleaved_sliding_window: list[int | None]
     sliding_window_pattern: int
-    rope_theta: int
+    rope_parameters: dict[str, Any]
     rope_local_theta: int
     # MLP
     intermediate_size: int
@@ -153,13 +153,24 @@ class Plamo3AttentionMixer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
-        layer_idx = extract_layer_index(prefix)
-        full_attn = config.interleaved_sliding_window[layer_idx] is None
-        self.rope_theta = config.rope_theta if full_attn else config.rope_local_theta
-        self.rope_scaling = (
-            config.rope_scaling if hasattr(config, "rope_scaling") else None
-        )
+        layer_idx = extract_layer_index(prefix)
+        layer_type = config.layer_types[layer_idx]
+        is_sliding = layer_type == "sliding_attention"
+
+        # Initialize the rotary embedding.
+        if layer_type in config.rope_parameters:
+            # Transformers v5 rope config.
+            rope_parameters = config.rope_parameters[layer_type]
+        else:
+            # Transformers v4 rope config.
+            # Global attention. Use the values in config.json.
+            rope_parameters = config.rope_parameters
+            # Local attention. Override the values in config.json.
+            if is_sliding:
+                rope_parameters = dict(
+                    rope_type="default", rope_theta=config.rope_local_theta
+                )
 
         max_position = config.max_position_embeddings
         if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
             vllm_config.model_config.max_model_len, int
@@ -170,8 +181,7 @@ class Plamo3AttentionMixer(nn.Module):
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
-            base=self.rope_theta,
-            rope_scaling=self.rope_scaling,
+            rope_parameters=rope_parameters,
         )
         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         set_weight_attrs(
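
For context, a minimal sketch of the two rope_parameters shapes the new branch distinguishes. The concrete keys and theta values below are illustrative assumptions, not values taken from a real PLaMo 3 config.json; only the overall structure mirrors the check in the diff.

# Transformers v4-style: one flat rope dict shared by every layer; sliding
# (local) attention layers override it via config.rope_local_theta, as the
# `if is_sliding:` branch above does.
rope_parameters_v4 = {"rope_type": "default", "rope_theta": 1_000_000}

# Transformers v5-style: keyed by layer type, so the
# `layer_type in config.rope_parameters` check selects the entry directly.
rope_parameters_v5 = {
    "full_attention": {"rope_type": "default", "rope_theta": 1_000_000},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000},
}

# Example of the v5 lookup path (layer_type would come from
# config.layer_types[layer_idx] in the actual module):
layer_type = "sliding_attention"
if layer_type in rope_parameters_v5:
    rope_parameters = rope_parameters_v5[layer_type]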