mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 00:07:06 +08:00
fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
parent
ec8c1cf732
commit
f23d126a07
@ -188,8 +188,9 @@ class MLAImplCommon(AttentionImpl):
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
x = torch.matmul(x, self.W_Q)
|
||||
return torch.matmul(x, self.W_UK.T)\
|
||||
x = torch.matmul(x, self.W_Q)\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim)
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
@ -249,13 +250,15 @@ class MLAImplCommon(AttentionImpl):
|
||||
self.W_UV_O.shape[0] * tp_size,
|
||||
self.W_UV_O.shape[1],
|
||||
bias=False,
|
||||
#quant_config=self.o_proj.quant_method, TODO
|
||||
# TODO(lucas) figure out how to properly forward quant_method
|
||||
#quant_config=self.o_proj.quant_method,
|
||||
)
|
||||
|
||||
self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
|
||||
else:
|
||||
print("Not absorbing weights")
|
||||
self.W_UK, self.W_UV, self.W_Q = W_UK, W_UV, W_Q
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_prefill(
|
||||
|
||||
@ -124,7 +124,7 @@ class TritonMLAState(AttentionState):
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TritonMLAMetadata(MLAMetadataCommon):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
"""Metadata for TritonMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
@ -189,7 +189,7 @@ class TritonMLAMetadata(MLAMetadataCommon):
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@ -512,9 +512,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
|
||||
lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
|
||||
|
||||
# Flag that can control whether
|
||||
#
|
||||
#
|
||||
# Flag that can control whether or not we perform matrix-absorption for MLA
|
||||
# decode, i.e. absorb W_UK into W_Q and W_UV into W_O; absorbing the
|
||||
# matrices reduces the runtime FLOPs needed to compute MLA but requires
|
||||
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
|
||||
# this is enabled by default
|
||||
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
|
||||
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user