mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:35:01 +08:00
[Bugfix] Fix non-contiguous tensor error in rocm_unquantized_gemm_impl (#27605)
Signed-off-by: zhewenli <zhewenli@meta.com>
This commit is contained in:
parent
83fd49b1fc
commit
8b62495076
@ -286,7 +286,7 @@ steps:
|
|||||||
|
|
||||||
- label: Engine Test # 25min
|
- label: Engine Test # 25min
|
||||||
timeout_in_minutes: 40
|
timeout_in_minutes: 40
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental, amdproduction]
|
||||||
agent_pool: mi325_1
|
agent_pool: mi325_1
|
||||||
#grade: Blocking
|
#grade: Blocking
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
@ -119,17 +119,17 @@ def rocm_unquantized_gemm_impl(
|
|||||||
if use_skinny is not True:
|
if use_skinny is not True:
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
x_view = x.view(-1, x.size(-1))
|
x_view = x.reshape(-1, x.size(-1))
|
||||||
n = x_view.shape[0]
|
n = x_view.shape[0]
|
||||||
m = weight.shape[0]
|
m = weight.shape[0]
|
||||||
cu_count = current_platform.get_cu_count()
|
cu_count = current_platform.get_cu_count()
|
||||||
|
|
||||||
if m > 8 and 0 < n <= 4:
|
if m > 8 and 0 < n <= 4:
|
||||||
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
out = ops.wvSplitK(weight, x_view, cu_count, bias)
|
||||||
return out.view(*x.shape[:-1], weight.shape[0])
|
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||||
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
elif m % 4 == 0 and n == 1 and k <= 8192 and bias is None:
|
||||||
out = ops.LLMM1(weight, x_view, 4)
|
out = ops.LLMM1(weight, x_view, 4)
|
||||||
return out.view(*x.shape[:-1], weight.shape[0])
|
return out.reshape(*x.shape[:-1], weight.shape[0])
|
||||||
return torch.nn.functional.linear(x, weight, bias)
|
return torch.nn.functional.linear(x, weight, bias)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user