diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index e1dc711778ff..486ebe1d464c 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, - const int query_stride, - const int key_stride, + const int64_t query_stride, + const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size) { @@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel( const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * key_stride + head_idx * head_size; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -89,8 +89,8 @@ void rotary_embedding( int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size; - int query_stride = query.stride(-2); - int key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512));