From 6ef00b03a2b6679d494530e9a98932c5a6cc8418 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 3 Jan 2024 09:52:29 -0800
Subject: [PATCH] Enable CUDA graph for GPTQ & SqueezeLLM (#2318)

---
 csrc/quantization/gptq/q_gemm.cu      | 18 ++++++++++++------
 .../squeezellm/quant_cuda_kernel.cu   |  4 +++-
 vllm/config.py                        |  6 ------
 3 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index eb0d75f1293c..a5d2345f1e7f 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
 
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
 
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
 
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
 
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
 
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         q_weight,
         new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 
 }  // namespace gptq
diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
index b17ced6fce79..09964903622b 100644
--- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu
+++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -200,8 +200,10 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
-  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else
diff --git a/vllm/config.py b/vllm/config.py
index ff9a1308a5c8..f1efcc66e909 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -181,12 +181,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True
 
     def verify_with_parallel_config(
         self,
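
Note on the recurring pattern above (explanatory, not part of the patch):
every hunk replaces a launch on the legacy default stream,
kernel<<<gridDim, blockDim>>>, with a launch on the stream PyTorch considers
current, kernel<<<gridDim, blockDim, 0, stream>>> with
stream = at::cuda::getCurrentCUDAStream(). CUDA graph capture records only
the work submitted to the stream being captured, and a launch on the legacy
default stream during capture is an error; once the quantized kernels honor
the current stream, the guard in vllm/config.py that forced eager mode for
GPTQ and SqueezeLLM can be dropped. The sketch below is a minimal,
self-contained illustration of that capture/replay pattern; the kernel name,
sizes, and omitted error checks are illustrative only and not taken from vLLM.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void bump_kernel(float* x, int n) {
  // Toy kernel standing in for the quantized GEMM kernels in the patch.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}

int main() {
  const int n = 1024;
  float* d_x;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMemset(d_x, 0, n * sizeof(float));

  // Capture requires an explicit, non-default stream.
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  // The 4th launch-configuration argument is the stream -- the same slot
  // the patch fills with at::cuda::getCurrentCUDAStream().
  bump_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_x, n);
  cudaStreamEndCapture(stream, &graph);

  // cudaGraphInstantiateWithFlags is available from CUDA 11.4 onward.
  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);  // replay the captured launch
  cudaStreamSynchronize(stream);

  float host = 0.0f;
  cudaMemcpy(&host, d_x, sizeof(float), cudaMemcpyDeviceToHost);
  printf("x[0] = %f\n", host);  // expected: 1.0 after one replay

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d_x);
  return 0;
}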