[MISC] More AMD unused var clean up (#14926)

Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
Lu Fang 2025-03-17 01:40:41 -07:00 committed by GitHub
parent 0a74bfce9c
commit cd0cd85102
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -127,7 +127,7 @@ __device__ __forceinline__ T from_float(const float& inp) {
template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
union tmpcvt {
[[maybe_unused]] union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
@ -160,7 +160,7 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
const _B16x4& inp2) {
union tmpcvt {
[[maybe_unused]] union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
@ -1273,9 +1273,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
[[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
// max num partitions supported is warp_size * NPAR_LOOPS