mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
Minor fix on AWQ kernel launch (#1356)
This commit is contained in:
parent
d0740dff1b
commit
29678cd213
@ -534,6 +534,7 @@ torch::Tensor awq_gemm(
|
|||||||
if (num_out_channels % group_size != 0)
|
if (num_out_channels % group_size != 0)
|
||||||
throw std::invalid_argument("OC is not multiple of Group size");
|
throw std::invalid_argument("OC is not multiple of Group size");
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
if (num_out_channels % 128 == 0)
|
if (num_out_channels % 128 == 0)
|
||||||
{
|
{
|
||||||
int j_factors1 = num_out_channels / 128 / 1;
|
int j_factors1 = num_out_channels / 128 / 1;
|
||||||
@ -541,7 +542,7 @@ torch::Tensor awq_gemm(
|
|||||||
// threadIdx.x: 32
|
// threadIdx.x: 32
|
||||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
dim3 threads_per_block(32, 2);
|
dim3 threads_per_block(32, 2);
|
||||||
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
|
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
}
|
}
|
||||||
else if (num_out_channels % 64 == 0)
|
else if (num_out_channels % 64 == 0)
|
||||||
@ -552,7 +553,7 @@ torch::Tensor awq_gemm(
|
|||||||
// threadIdx.x: 32
|
// threadIdx.x: 32
|
||||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
dim3 threads_per_block(32, 2);
|
dim3 threads_per_block(32, 2);
|
||||||
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
|
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
|
||||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
}
|
}
|
||||||
return _out_feats.sum(0);
|
return _out_feats.sum(0);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user