Fix non-contiguous input passed to Marlin kernel (#15319)

This commit is contained in:
Qubitium-ModelCloud 2025-03-24 11:09:44 +08:00 committed by GitHub
parent f622dbcf39
commit d20e261199
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -115,6 +115,10 @@ class MarlinLinearKernel(MPLinearKernel):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# marlin requires contiguous memory layout
# prefix caching may cause x to be non-contiguous
x = x.contiguous() # no-op if already contiguous
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)