[BugFix][Model] Fix commandr RoPE max_position_embeddings (#3919)

commit d036198e23
parent 59a6abf3c9
Author: Roy
Date: 2024-04-09 06:17:21 +08:00 (committed by GitHub)

@@ -140,7 +140,9 @@ class CohereAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.max_position_embeddings = config.max_position_embeddings
+        self.max_position_embeddings = getattr(
+            config, "model_max_length", None) or getattr(
+                config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)