[Bugfix] Fix Plamo3 rope handling (#29092)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 56669c1f29
commit 0e741c12e3
@@ -62,7 +62,7 @@ class Plamo3Config(PretrainedConfig):  # type: ignore
     # if `sliding_window` is list
     interleaved_sliding_window: list[int | None]
     sliding_window_pattern: int
-    rope_theta: int
+    rope_parameters: dict[str, Any]
     rope_local_theta: int
     # MLP
     intermediate_size: int
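For context, the two shapes this `rope_parameters` field can take look roughly as follows. This is an illustrative sketch inferred from the diff; the `full_attention` key and the theta values are assumptions, not copied from a real Plamo3 config.json:

# Illustrative only: plausible shapes for `rope_parameters`.
# The "full_attention" key and the theta values are assumptions.

# Transformers v4 style: one flat dict shared by every layer.
rope_parameters_v4 = {"rope_type": "default", "rope_theta": 1_000_000.0}

# Transformers v5 style: keyed by layer type, so global and sliding
# attention layers can carry different rope settings.
rope_parameters_v5 = {
    "full_attention": {"rope_type": "default", "rope_theta": 1_000_000.0},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
}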
@@ -153,13 +153,24 @@ class Plamo3AttentionMixer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
-        layer_idx = extract_layer_index(prefix)
-        full_attn = config.interleaved_sliding_window[layer_idx] is None
-
-        self.rope_theta = config.rope_theta if full_attn else config.rope_local_theta
-        self.rope_scaling = (
-            config.rope_scaling if hasattr(config, "rope_scaling") else None
-        )
+        layer_idx = extract_layer_index(prefix)
+        layer_type = config.layer_types[layer_idx]
+        is_sliding = layer_type == "sliding_attention"
+
+        # Initialize the rotary embedding.
+        if layer_type in config.rope_parameters:
+            # Transformers v5 rope config.
+            rope_parameters = config.rope_parameters[layer_type]
+        else:
+            # Transformers v4 rope config.
+            # Global attention. Use the values in config.json.
+            rope_parameters = config.rope_parameters
+            # Local attention. Override the values in config.json.
+            if is_sliding:
+                rope_parameters = dict(
+                    rope_type="default", rope_theta=config.rope_local_theta
+                )
+
         max_position = config.max_position_embeddings
         if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
             vllm_config.model_config.max_model_len, int
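The branch above is the heart of the fix: it resolves one rope-parameter dict per layer, handling both config formats. A self-contained sketch of that selection logic, using `SimpleNamespace` as a stand-in for the real Plamo3 config object:

from types import SimpleNamespace

def select_rope_parameters(config, layer_idx: int) -> dict:
    """Mirror of the per-layer rope selection in the diff above."""
    layer_type = config.layer_types[layer_idx]
    if layer_type in config.rope_parameters:
        # Transformers v5 config: already keyed by layer type.
        return config.rope_parameters[layer_type]
    # Transformers v4 config: one flat dict for global attention,
    # overridden with the local theta for sliding layers.
    if layer_type == "sliding_attention":
        return dict(rope_type="default", rope_theta=config.rope_local_theta)
    return config.rope_parameters

# v4-style config: only the sliding layer gets the local theta override.
config = SimpleNamespace(
    layer_types=["full_attention", "sliding_attention"],
    rope_parameters={"rope_type": "default", "rope_theta": 1_000_000.0},
    rope_local_theta=10_000.0,
)
assert select_rope_parameters(config, 0)["rope_theta"] == 1_000_000.0
assert select_rope_parameters(config, 1)["rope_theta"] == 10_000.0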
@@ -170,8 +181,7 @@ class Plamo3AttentionMixer(nn.Module):
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
-            base=self.rope_theta,
-            rope_scaling=self.rope_scaling,
+            rope_parameters=rope_parameters,
         )
         self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         set_weight_attrs(
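With the call site updated, `get_rope` receives the whole per-layer dict instead of the separate `base` and `rope_scaling` arguments it got before. To see why the per-layer theta matters, here is a hedged sketch (not vLLM's `get_rope`) of how `rope_theta` enters the standard RoPE inverse-frequency formula:

# Sketch only: NOT vLLM's get_rope. Shows how the selected rope_theta
# drives the standard RoPE inverse frequencies
#   inv_freq_i = theta ** (-2 * i / rotary_dim).
def rope_inv_freq(rotary_dim: int, rope_parameters: dict) -> list[float]:
    theta = rope_parameters["rope_theta"]
    return [theta ** (-2.0 * i / rotary_dim) for i in range(rotary_dim // 2)]

# A smaller local theta (sliding layers) keeps frequencies higher, i.e.
# positions rotate faster, than the large global theta does.
local = rope_inv_freq(8, {"rope_type": "default", "rope_theta": 10_000.0})
glob = rope_inv_freq(8, {"rope_type": "default", "rope_theta": 1_000_000.0})
assert all(lo >= gl for lo, gl in zip(local, glob))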