[CPU][Bugfix] Fix Apple Silicon M1 compilation failure (#28681)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin 2025-11-13 20:49:55 -05:00 committed by GitHub
parent 2aa75c752b
commit 622e6106a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,6 +5,10 @@
#include <type_traits>
#include <cstddef>
#if defined(__APPLE__)
#include <sys/sysctl.h>
#endif
#include "cpu_types.hpp"
#include "scratchpad_manager.h"
#include "cpu_attn_macros.h"
@ -741,9 +745,21 @@ class AttentionScheduler {
static int64_t get_available_l2_size() {
static int64_t size = []() {
#if defined(__APPLE__)
// macOS doesn't have _SC_LEVEL2_CACHE_SIZE. Use sysctlbyname.
int64_t l2_cache_size = 0;
size_t len = sizeof(l2_cache_size);
if (sysctlbyname("hw.l2cachesize", &l2_cache_size, &len, NULL, 0) == 0 &&
l2_cache_size > 0) {
return l2_cache_size >> 1; // use 50% of L2 cache
}
// Fallback if sysctlbyname fails
return 128 * 1024 >> 1; // use 50% of 128KB
#else
long l2_cache_size = sysconf(_SC_LEVEL2_CACHE_SIZE);
TORCH_CHECK_NE(l2_cache_size, -1);
return l2_cache_size >> 1; // use 50% of L2 cache
#endif
}();
return size;
}
@ -816,10 +832,14 @@ struct VecTypeTrait<float> {
using vec_t = vec_op::FP32Vec16;
};
// ARM only supports BF16 with ARMv8.6-A extension
#if (defined(__aarch64__) && !defined(ARM_BF16_SUPPORT))
#else
template <>
struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#endif
#if !defined(__powerpc__)
template <>
@ -1588,9 +1608,17 @@ class AttentionMainLoop {
if (use_sink) {
alignas(64) float s_aux_fp32[16];
#if defined(__aarch64__) && !defined(ARM_BF16_SUPPORT)
// ARM without native BF16 support: manual conversion
for (int i = 0; i < 16; ++i) {
s_aux_fp32[i] = static_cast<float>(curr_s_aux[i]);
}
#else
// All other platforms have BF16Vec16 available
vec_op::BF16Vec16 vec_bf16(curr_s_aux);
vec_op::FP32Vec16 vec_fp32(vec_bf16);
vec_fp32.save(s_aux_fp32);
#endif
float* __restrict__ curr_sum_buffer = sum_buffer;
float* __restrict__ curr_max_buffer = max_buffer;