Update utils.py

Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Lucas Wilkinson 2025-01-29 22:46:43 -05:00 committed by GitHub
parent 3895bba85a
commit 4880a43d20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -231,7 +231,7 @@ class MLAImplCommon(AttentionImpl):
.flatten(start_dim=0, end_dim=1).contiguous()
tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absored = RowParallelLinear(
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
@ -239,7 +239,7 @@ class MLAImplCommon(AttentionImpl):
#quant_config=self.o_proj.quant_method,
)
self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
else:
self.W_UV = W_UV
self.W_UK = W_UK