[Bugfix] Fix non-contiguous tensor error in rocm_unquantized_gemm_impl (#27605)

Signed-off-by: zhewenli <zhewenli@meta.com>
This commit is contained in:
Zhewen Li 2025-10-29 00:00:15 -07:00 committed by GitHub
parent 83fd49b1fc
commit 8b62495076
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 4 deletions

View File

@@ -286,7 +286,7 @@ steps:
- label: Engine Test # 25min
timeout_in_minutes: 40
-mirror_hardwares: [amdexperimental]
+mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
#grade: Blocking
source_file_dependencies:

View File

@@ -119,17 +119,17 @@ def rocm_unquantized_gemm_impl(
if use_skinny is not True:
return torch.nn.functional.linear(x, weight, bias)
-x_view = x.view(-1, x.size(-1))
+x_view = x.reshape(-1, x.size(-1))
n = x_view.shape[0]
m = weight.shape[0]
cu_count = current_platform.get_cu_count()
if m > 8 and 0 < n <= 4:
out = ops.wvSplitK(weight, x_view, cu_count, bias)
-return out.view(*x.shape[:-1], weight.shape[0])
+return out.reshape(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
out = ops.LLMM1(weight, x_view, 4)
-return out.view(*x.shape[:-1], weight.shape[0])
+return out.reshape(*x.shape[:-1], weight.shape[0])
return torch.nn.functional.linear(x, weight, bias)