mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 04:22:15 +08:00
[Bugfix] Fix cutlass dispatch for fp8/int8 to properly invoke M<=16 c… (#16751)
Signed-off-by: Ther-LF <2639852836@qq.com>
This commit is contained in:
parent
d1aeea7553
commit
c12df53b60
@@ -336,7 +336,7 @@ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
|||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||||
|
|
||||||
if (mp2 <= 16) {
|
if (mp2 <= 16) {
|
||||||
// M in [1, 16]
|
// M in [1, 16]
|
||||||
|
|||||||
@@ -321,7 +321,7 @@ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
|||||||
|
|
||||||
uint32_t const m = a.size(0);
|
uint32_t const m = a.size(0);
|
||||||
uint32_t const mp2 =
|
uint32_t const mp2 =
|
||||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||||
|
|
||||||
if (mp2 <= 16) {
|
if (mp2 <= 16) {
|
||||||
// M in [1, 16]
|
// M in [1, 16]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user