fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson 2025-01-30 03:02:58 +00:00
parent ec8c1cf732
commit f23d126a07
3 changed files with 15 additions and 10 deletions

View File

@ -188,8 +188,9 @@ class MLAImplCommon(AttentionImpl):
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)
return torch.matmul(x, self.W_UK.T)\
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self):
@ -249,13 +250,15 @@ class MLAImplCommon(AttentionImpl):
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
#quant_config=self.o_proj.quant_method, TODO
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
)
self.o_proj_absored.weight = torch.nn.Parameter(self.W_UV_O.T)
else:
print("Not absorbing weights")
self.W_UK, self.W_UV, self.W_Q = W_UK, W_UV, W_Q
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
@abstractmethod
def _forward_prefill(

View File

@ -124,7 +124,7 @@ class TritonMLAState(AttentionState):
@dataclass(kw_only=True)
class TritonMLAMetadata(MLAMetadataCommon):
"""Metadata for FlashAttentionBackend.
"""Metadata for TritonMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
@ -189,7 +189,7 @@ class TritonMLAMetadata(MLAMetadataCommon):
num_prefill_tokens: int
num_kv_splits: int = 4
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None

View File

@ -512,9 +512,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
# Flag that can control whether
#
#
# Flag that can control whether or not we perform matrix absorption for MLA
# decode, i.e. absorb W_UK into W_Q (forming W_Q_UK) and W_UV into W_O
# (forming W_UV_O). Absorbing the matrices reduces the runtime FLOPs needed
# to compute MLA, but requires storing the larger weights W_Q_UK and W_UV_O,
# so it can increase memory usage. This is enabled by default.
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
}