mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 22:39:08 +08:00
Update utils.py
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
parent
3895bba85a
commit
4880a43d20
@@ -231,7 +231,7 @@ class MLAImplCommon(AttentionImpl):
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.o_proj_absored = RowParallelLinear(
|
||||
self.o_proj_absorbed = RowParallelLinear(
|
||||
self.W_UV_O.shape[0] * tp_size,
|
||||
self.W_UV_O.shape[1],
|
||||
bias=False,
|
||||
@@ -239,7 +239,7 @@ class MLAImplCommon(AttentionImpl):
|
||||
#quant_config=self.o_proj.quant_method,
|
||||
)
|
||||
|
||||
self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
|
||||
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
|
||||
else:
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user