From ded8ada86a3962477433054debbcef1d45161850 Mon Sep 17 00:00:00 2001
From: Bram Wasti
Date: Thu, 30 Oct 2025 01:28:45 -0400
Subject: [PATCH] Add more dims for batch invariant shims (#27489)

Signed-off-by: Bram Wasti
Signed-off-by: Bram Wasti
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 vllm/model_executor/layers/batch_invariant.py | 44 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 208ffb30e5ed2..5706786bccb1d 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -478,9 +478,48 @@ def matmul_batch_invariant(a, b, *, out=None):
     elif a.ndim == 3 and b.ndim == 3:
         # Handle batched case like bmm
         return bmm_batch_invariant(a, b, out=out)
+    elif a.ndim == 3 and b.ndim == 2:
+        # Handle 3D x 2D: common for linear layers
+        # (batch, seq, hidden) @ (hidden, out) -> (batch, seq, out)
+        # Reshape to 2D, do mm, reshape back
+        batch, seq, hidden = a.shape
+        a_2d = a.reshape(-1, hidden)
+        result_2d = matmul_persistent(a_2d, b)
+        result = result_2d.reshape(batch, seq, -1)
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    elif a.ndim == 2 and b.ndim == 3:
+        # Handle 2D x 3D: (M, K) @ (B, K, N) -> (B, M, N)
+        # By broadcasting `a` to 3D, we can reuse the batched matrix
+        # multiplication logic.
+        a_expanded = a.unsqueeze(0).expand(b.shape[0], -1, -1)
+        return bmm_batch_invariant(a_expanded, b, out=out)
+    elif a.ndim == 4 and b.ndim == 4:
+        # Handle 4D attention tensors: [batch, heads, seq, dim]
+        # Reshape to 3D, process, reshape back
+        batch, heads, seq_a, dim_a = a.shape
+        _, _, dim_b, seq_b = b.shape
+
+        # Reshape to [batch*heads, seq_a, dim_a]
+        a_3d = a.reshape(batch * heads, seq_a, dim_a)
+        b_3d = b.reshape(batch * heads, dim_b, seq_b)
+
+        # Do batched matmul
+        result_3d = bmm_batch_invariant(a_3d, b_3d)
+
+        # Reshape back to [batch, heads, seq_a, seq_b]
+        result = result_3d.reshape(batch, heads, seq_a, seq_b)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
     else:
         raise ValueError(
-            f"matmul_batch_invariant currently only supports 2D x 2D and 3D x 3D, "
+            f"matmul_batch_invariant currently only supports 2D x 2D, 3D x 3D, "
+            f"3D x 2D, 2D x 3D, and 4D x 4D, "
             f"got shapes {a.shape} and {b.shape}"
         )
 
@@ -667,7 +706,8 @@ def rms_norm_batch_invariant(
 
 
 def linear_batch_invariant(input, weight, bias=None):
-    output = mm_batch_invariant(input, weight.t())
+    output = matmul_batch_invariant(input, weight.t())
+
     if bias is not None:
         output = output + bias
     return output
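
A quick sanity sketch of the new shape paths, not part of the patch itself. It
assumes a CUDA device with the Triton-backed batch-invariant kernels available
and imports from the patched module; all shapes are illustrative. Each call
should agree with torch.matmul on output shape.

    import torch

    from vllm.model_executor.layers.batch_invariant import matmul_batch_invariant

    # 3D x 2D: the linear-layer case, (batch, seq, hidden) @ (hidden, out)
    a = torch.randn(2, 5, 8, device="cuda")
    w = torch.randn(8, 16, device="cuda")
    assert matmul_batch_invariant(a, w).shape == torch.matmul(a, w).shape  # (2, 5, 16)

    # 2D x 3D: `a` is broadcast across the batch dim of `b`
    a2 = torch.randn(5, 8, device="cuda")
    b3 = torch.randn(4, 8, 16, device="cuda")
    assert matmul_batch_invariant(a2, b3).shape == torch.matmul(a2, b3).shape  # (4, 5, 16)

    # 4D x 4D: attention-style [batch, heads, seq, dim] @ [batch, heads, dim, seq]
    q = torch.randn(2, 8, 5, 64, device="cuda")
    k = torch.randn(2, 8, 64, 5, device="cuda")
    assert matmul_batch_invariant(q, k).shape == torch.matmul(q, k).shape  # (2, 8, 5, 5)

Note that the 2D x 3D branch broadcasts with unsqueeze(0).expand(...), which
yields a stride-0 view of `a` rather than a copy before bmm_batch_invariant
runs.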
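
The linear_batch_invariant change follows from the same shape work:
mm_batch_invariant presumably mirrors torch.mm's 2D-only contract, while
matmul_batch_invariant can now route a 3D activation through the 3D x 2D
branch above. A minimal sketch under the same assumptions:

    import torch

    from vllm.model_executor.layers.batch_invariant import linear_batch_invariant

    x = torch.randn(2, 5, 8, device="cuda")     # (batch, seq, in_features)
    weight = torch.randn(16, 8, device="cuda")  # (out_features, in_features), nn.Linear layout
    bias = torch.randn(16, device="cuda")

    y = linear_batch_invariant(x, weight, bias)  # weight.t() makes this 3D x 2D
    assert y.shape == (2, 5, 16)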