Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 02:45:37 +08:00)
Enable CUDA graph for GPTQ & SqueezeLLM (#2318)

parent 9140561059
commit 6ef00b03a2
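Every kernel-side hunk in this diff applies the same one-line pattern: the two-argument launch configuration <<<gridDim, blockDim>>>, which submits work to the legacy default CUDA stream, becomes a launch on the stream PyTorch is currently using. CUDA graphs are built by capturing work from a specific stream, so launches that bypassed that stream previously forced GPTQ and SqueezeLLM to run in eager mode. Below is a minimal sketch of the pattern, assuming the ATen headers are available; the kernel and wrapper names are hypothetical, not vLLM identifiers.

// Sketch only: illustrates the launch pattern used throughout this commit.
// my_kernel / launch_my_kernel are hypothetical names, not vLLM identifiers.
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()
#include <cuda_runtime.h>

__global__ void my_kernel(float* out, const float* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * in[i];
}

void launch_my_kernel(float* out, const float* in, int n) {
  dim3 blockDim(256);
  dim3 gridDim((n + blockDim.x - 1) / blockDim.x);
  // Old style: my_kernel<<<gridDim, blockDim>>>(out, in, n);
  // runs on the legacy default stream and escapes CUDA graph capture.
  // New style: launch on the stream PyTorch is currently using
  // (the third launch argument is dynamic shared memory, 0 here).
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  my_kernel<<<gridDim, blockDim, 0, stream>>>(out, in, n);
}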
@@ -287,7 +287,8 @@ void gemm_half_q_half_cuda_part
 
     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
 
-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
@@ -434,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
 
-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
@@ -567,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
 
-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
@@ -639,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -794,7 +798,8 @@ void shuffle_exllama_weight
         gridDim.x = DIVIDE(width, THREADS_X);
         gridDim.y = height / 8;
 
-        make_sequential_kernel<<<gridDim, blockDim>>>
+        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+        make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
         (
             q_weight,
             new_qweight,
@@ -813,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }
 
 }  // namespace gptq
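As background (not part of the diff): the stream argument matters because cudaStreamBeginCapture records into the graph only the work submitted to the capturing stream; a plain <<<gridDim, blockDim>>> launch goes to the legacy default stream and is either rejected during capture or left out of the graph. A self-contained sketch of that rule, assuming the CUDA graph API (CUDA 10 or newer); the kernel here is a throwaway example, not vLLM code.

// Illustrative sketch, not vLLM code: CUDA graph capture records only the
// work submitted to the stream that is being captured, which is why the
// launches in this commit now take the current PyTorch stream.
#include <cuda_runtime.h>

__global__ void scale_kernel(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= 2.0f;
}

int main() {
  const int n = 1024;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  // Everything submitted to `stream` between Begin/EndCapture becomes a
  // graph node; a <<<grid, block>>> launch without the stream argument
  // would target the default stream and not be captured.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  scale_kernel<<<n / 256, 256, 0, stream>>>(x, n);
  cudaStreamEndCapture(stream, &graph);

  // CUDA 12 signature; CUDA 11 uses the five-argument form
  // cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0).
  cudaGraphInstantiate(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);  // replay the captured kernel
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  return 0;
}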
@@ -200,8 +200,10 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
-  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
 #ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
 #else
@@ -181,12 +181,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True
 
     def verify_with_parallel_config(
         self,