fix stablelm.py tensor-parallel-size bug (#2482)

This commit is contained in:
YingchaoX 2024-01-19 01:39:46 +08:00 committed by GitHub
parent d10f8e1d43
commit 8a25d3a71a


@@ -99,7 +99,7 @@ class StablelmAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_key_value_heads * self.head_dim
-        if (self.head_dim * self.num_heads) != self.hidden_size:
+        if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads}).")