Enable CUDA graph for GPTQ & SqueezeLLM (#2318)

Woosuk Kwon 2024-01-03 09:52:29 -08:00 committed by GitHub
parent 9140561059
commit 6ef00b03a2
3 changed files with 15 additions and 13 deletions


@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;
-    make_sequential_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         q_weight,
         new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 } // namespace gptq
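
Every launch site touched above follows the same pattern: query the current PyTorch CUDA stream with at::cuda::getCurrentCUDAStream() and pass it as the fourth launch-configuration argument, so the kernel is enqueued on whatever stream PyTorch has active (including the side stream used during CUDA graph capture) rather than on the implicit default stream. A minimal sketch of that pattern, using a hypothetical kernel and wrapper that are not part of this diff:

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

// Hypothetical kernel, only to illustrate the launch pattern used above.
__global__ void dummy_kernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 0.0f;
}

// Launching on the stream returned by at::cuda::getCurrentCUDAStream()
// (instead of the implicit default stream) means the kernel runs on the
// stream PyTorch currently has active, so it can be recorded when that
// stream is being captured into a CUDA graph.
void launch_on_current_stream(float* out, int n) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  dummy_kernel<<<blocks, threads, 0, stream>>>(out, n);
}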


@@ -200,8 +200,10 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
-  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else

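The SqueezeLLM launcher additionally gains an at::cuda::OptionalCUDAGuard keyed to the input tensor, so the current device is switched to the tensor's GPU before the stream is queried and the kernel is launched. A minimal sketch of the guard-plus-stream pattern, with a hypothetical kernel and wrapper name standing in for the real ones:

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGuard.h>
#include <torch/extension.h>

// Hypothetical kernel standing in for NUQ4MatMulKernel.
__global__ void scale_kernel(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

// Hypothetical host wrapper: the guard makes vec's device current for the
// scope, so getCurrentCUDAStream() returns the active stream of that device
// and the launch lands where PyTorch expects it, including under graph capture.
// Assumes a float CUDA tensor.
void scale_on_vec_device(torch::Tensor vec) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int n = static_cast<int>(vec.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  scale_kernel<<<blocks, threads, 0, stream>>>(vec.data_ptr<float>(), n);
}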

@@ -181,12 +181,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True
 
     def verify_with_parallel_config(
         self,
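
With the override above removed, ModelConfig no longer flips enforce_eager to True for GPTQ and SqueezeLLM, so these models take the CUDA-graph path by default (eager mode can still be requested explicitly with enforce_eager=True). The stream changes are what make this safe: CUDA graph capture only records work submitted to the captured stream, and launches that implicitly target the legacy default stream either fail or escape the capture. A minimal, self-contained CUDA sketch of that behavior, illustrative only and not vLLM code:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void add_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] += 1.0f;
}

int main() {
  const int n = 1024;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t exec;

  // Begin capturing on a non-default stream; only work launched on `stream`
  // (like the kernels in this commit, now given an explicit stream) is recorded.
  // A launch written as add_one<<<blocks, threads>>>(...) would target the
  // legacy default stream and would either fail or escape the capture.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  add_one<<<(n + 255) / 256, 256, 0, stream>>>(x, n);
  cudaStreamEndCapture(stream, &graph);

  // Instantiate once, then replay the recorded work without per-launch CPU
  // overhead (cudaGraphInstantiateWithFlags requires CUDA 11.4+).
  cudaGraphInstantiateWithFlags(&exec, graph, 0);
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);
  printf("captured and replayed one kernel launch\n");

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  return 0;
}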