diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu
index 63542990b80f..04dfe8fe9b88 100644
--- a/csrc/quantization/awq/gemm_kernels.cu
+++ b/csrc/quantization/awq/gemm_kernels.cu
@@ -534,6 +534,7 @@ torch::Tensor awq_gemm(
     if (num_out_channels % group_size != 0)
         throw std::invalid_argument("OC is not multiple of Group size");
 
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (num_out_channels % 128 == 0)
     {
         int j_factors1 = num_out_channels / 128 / 1;
@@ -541,18 +542,18 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
     {
-        int j_factors1 = num_out_channels / 64 / 1;
+        int j_factors1 = num_out_channels / 64 / 1;
         dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
 
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
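For context, the pattern the patch applies is the standard ATen one: query the stream PyTorch is currently using with `at::cuda::getCurrentCUDAStream()` and pass it as the fourth argument of the kernel launch configuration instead of launching on the default stream. Below is a minimal self-contained sketch of that pattern; the `scale_kernel`/`scale` names are illustrative only and not part of this file.

```cuda
// Standalone sketch (hypothetical kernel and wrapper names, not from vLLM):
// launching a custom CUDA kernel on PyTorch's current stream.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

// Trivial element-wise kernel used only to demonstrate the launch pattern.
__global__ void scale_kernel(const float* in, float* out, float alpha, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = alpha * in[i];
}

torch::Tensor scale(torch::Tensor input, double alpha) {
  TORCH_CHECK(input.is_cuda() && input.scalar_type() == torch::kFloat32,
              "expected a float32 CUDA tensor");
  auto output = torch::empty_like(input);
  const int n = static_cast<int>(input.numel());

  dim3 threads_per_block(256);
  dim3 num_blocks((n + 255) / 256);

  // Ask ATen for the stream PyTorch is currently using on this device and
  // pass it as the fourth launch argument (the third is dynamic shared
  // memory in bytes). The two-argument <<<num_blocks, threads_per_block>>>
  // form would enqueue the kernel on the legacy default stream, which is
  // not ordered with work PyTorch has queued on a non-default stream
  // (e.g. during CUDA graph capture).
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
      input.data_ptr<float>(), output.data_ptr<float>(),
      static_cast<float>(alpha), n);
  return output;
}
```

Passing 0 for dynamic shared memory keeps the launch otherwise unchanged; only the target stream differs, which is exactly what the two kernel-launch edits in the diff do.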