[Kernel] Remove if-else with identical branches in marlin 2:4 (#10687)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2024-11-27 01:55:32 -05:00 committed by GitHub
parent 15cc2a9f1a
commit e2251109c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -296,13 +296,9 @@ __global__ void Marlin_24(
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
if (group_blocks != -1) {
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
} else {
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4;
}
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) / 4; // Note that in the original Marlin kernel
// this is (threadIdx.x % 32) / 4
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or