[Kernel] Squash a few more warnings (#6914)

This commit is contained in:
Tyler Michael Smith 2024-07-30 13:50:42 -04:00 committed by GitHub
parent 5cf9254a9c
commit cbbc904470
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 8 additions and 5 deletions

View File

@ -706,7 +706,7 @@ void paged_attention_v1_launcher(
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
@ -865,7 +865,7 @@ void paged_attention_v2_launcher(
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.

View File

@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {

View File

@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on

View File

@ -508,6 +508,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
@ -520,6 +521,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on

View File

@ -203,7 +203,8 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*)mul.data<at::Half>(), (__half*)lookup_table.data_ptr<at::Half>(),
(half2*)mul.data_ptr<at::Half>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#else
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),