diff --git a/tests/kernels/test_triton_flash_attention.py b/tests/kernels/test_triton_flash_attention.py new file mode 100644 index 0000000000000..cf2bdc908e420 --- /dev/null +++ b/tests/kernels/test_triton_flash_attention.py @@ -0,0 +1,499 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the triton_flash_attention kernel + +Run `pytest tests/kernels/test_triton_flash_attention.py`. +""" +import pytest +import torch + +from vllm.attention.ops.triton_flash_attention import (SUPPORTED_LAYOUTS, + MetaData, + compute_alibi_tensor, + scale_fp8, + triton_attention_rocm) +from vllm.platforms import current_platform + + +class ReferenceAttention: + + def __init__(self, Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, use_alibi, dtype, + input_metadata): + self.Z = Z + self.HQ = HQ + self.HK = HK + self.N_CTX_Q = N_CTX_Q + self.N_CTX_K = N_CTX_K + self.D_HEAD = D_HEAD + self.use_alibi = use_alibi + self.dtype = dtype + self.input_metadata = input_metadata + + def fwd(self, q, k, v): + scores = torch.einsum('bhqd,bhkd->bhqk', q, + k).float() * self.input_metadata.sm_scale + if self.input_metadata.causal: + mask = torch.tril(torch.ones(self.N_CTX_Q, + self.N_CTX_K, + device="cuda"), + diagonal=self.N_CTX_K - self.N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + if self.input_metadata.bias is not None: + scores += self.input_metadata.bias + + if self.use_alibi: + scores += compute_alibi_tensor(self.input_metadata.alibi_slopes, + self.N_CTX_Q, self.N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if self.input_metadata.causal: + # If N_CTX_Q > N_CTX_K, there's at least one row of all -infs going + # into softmax. This creates a row of NaNs as -inf - -inf == NaN. + # So we fix this by converting the NaNs to 0s, which is what they + # should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(self.dtype), v) + # compare + if self.input_metadata.layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + return ref_out + + def fwd_fp8(self, q_quantized, k_quantized, v_quantized): + q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to( + self.dtype) + k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to( + self.dtype) + v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to( + self.dtype) + result = self.fwd(q, k, v) + if self.input_metadata.o_scale is not None: + result, _ = scale_fp8(result, self.input_metadata.o_scale) + return result + + def fwd_fp8_kv(self, q, k_quantized, v_quantized): + k_descale, v_descale = (self.input_metadata.k_descale, + self.input_metadata.v_descale) + k_dequantized = (k_quantized.to(torch.float32) * + k_descale.to(torch.float32)).to(self.dtype) + v_dequantized = (v_quantized.to(torch.float32) * + v_descale.to(torch.float32)).to(self.dtype) + return self.fwd(q, k_dequantized, v_dequantized) + + def varlen_fwd(self, q, k, v, is_mqa=False): + ref_out = torch.empty_like(q) + if is_mqa: + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so + # the size aligns with Q. 
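+            # For example (hypothetical sizes): with HQ=48, HK=12 and
+            # D_HEAD=64, k of shape (total_k, 12, 64) is viewed as
+            # (total_k, 12, 1, 64) and expanded to (total_k, 12, 4, 64);
+            # the per-sequence reshape below then yields (seq_len, 48, 64),
+            # matching q's head count.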
+ k_ref = k.view(k.shape[0], k.shape[1], 1, + k.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, + v.shape[2]).expand(-1, -1, self.HQ // self.HK, -1) + else: + k_ref = k + v_ref = v + + for i in range(0, self.input_metadata.num_contexts): + start_q, start_k = self.input_metadata.cu_seqlens_q[ + i], self.input_metadata.cu_seqlens_k[i] + end_q, end_k = self.input_metadata.cu_seqlens_q[ + i + 1], self.input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + v_curr = v_ref[start_k:end_k] + if is_mqa: + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], + k_curr).float() + p = torch.softmax(scores * self.input_metadata.sm_scale, + dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + return ref_out + + +def quantize_input(q, k, v, fp8_kv=False, use_o_scale=False): + q_descale = None + if not fp8_kv: + q, q_descale = scale_fp8(q) + k, k_descale = scale_fp8(k) + v, v_descale = scale_fp8(v) + + # In real world use case, the p scale would be a parameter trained by the + # model. + p_scale = None + + o_scale = torch.rand(1, device="cuda", + requires_grad=False) if use_o_scale else None + + return q, k, v, q_descale, k_descale, v_descale, p_scale, o_scale + + +def input_helper( + Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout=None, + use_alibi=None, + causal=None, + is_fp8=False, + fp8_kv=False, + use_o_scale=False, + use_bias=False, +): + assert layout in SUPPORTED_LAYOUTS, "Got unsupported layout." + + current_platform.seed_everything(0) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts + # 2^(-8/n) + alibi_slopes = torch.tensor( + [2**(-8 / HQ * i) for i in range(1, HQ + 1)], + dtype=torch.float32, + device="cuda").repeat(Z, 1) + else: + alibi_slopes = None + + if use_bias: + bias = torch.randn((1, HQ, N_CTX_Q, N_CTX_K), + dtype=dtype, + device="cuda", + requires_grad=False) + else: + bias = None + + q = torch.randn(q_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + k = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + v = torch.randn(k_tensor_shape, + dtype=dtype, + device="cuda", + requires_grad=False) + + if is_fp8: + (q, k, v, q_descale, k_descale, v_descale, p_scale, + o_scale) = quantize_input(q, + k, + v, + use_o_scale=use_o_scale, + fp8_kv=fp8_kv) + else: + q_descale = k_descale = v_descale = p_scale = o_scale = None + + input_metadata = MetaData(sm_scale=D_HEAD**-0.5, + max_seqlens_q=N_CTX_Q, + max_seqlens_k=N_CTX_K, + layout=layout, + alibi_slopes=alibi_slopes, + alibi_batch=Z, + alibi_nheads=HQ, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + p_scale=p_scale, + o_scale=o_scale, + bias=bias, + seqlen_q=N_CTX_Q, + seqlen_k=N_CTX_K) + return q, k, v, input_metadata + + +def varlen_input_helper(Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + equal_seqlens=False): + current_platform.seed_everything(0) + + # Random sequence lengths. 
Using N_CTX as kind of max of sum of individual + # seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + seqlens_q = torch.randint(1, + max_seqlens_q + 1, (Z, ), + dtype=torch.int32) + seqlens_k = torch.randint(1, + max_seqlens_k + 1, (Z, ), + dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_q.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_k = torch.cat([ + torch.tensor([0], dtype=torch.int32), + seqlens_k.cumsum(dim=0, dtype=torch.int32) + ]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, + device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (1, 48, 12, 1, 1, 64), + (4, 4, 4, 128, 128, 65), + (16, 48, 48, 1, 1, 128), + (64, 48, 24, 3, 3, 128), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd']) +def test_op_fwd(Z, + HQ, + HK, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + use_alibi, + layout, + dtype=torch.float16): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, + dtype, layout, use_alibi, causal) + + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all + # layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(k.shape[0], -1, k.shape[2], + k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, + -1).reshape(v.shape[0], -1, v.shape[2], + v.shape[3]) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, + use_alibi, dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +@pytest.mark.parametrize('use_o_scale', [True, False]) +@pytest.mark.skipif(torch.cuda.get_device_capability() < (9, 0), + reason="Triton FP8 requires CUDA 9.0 or higher") +def test_op_fwd_fp8(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + use_o_scale, + dtype=torch.float32): + current_platform.seed_everything(0) + + # Disable grad to 
save memory it won't run into OOM on CI machine. + # q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, + # dtype, layout) + + q_quantized, k_quantized, v_quantized, input_metadata = input_helper( + Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + use_o_scale=use_o_scale) + + o = torch.empty_like(q_quantized) if use_o_scale else None + + tri_out, _ = triton_attention_rocm(q_quantized, k_quantized, v_quantized, + o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8(q_quantized, k_quantized, v_quantized) + + # compare + torch.testing.assert_close(ref_out.to(torch.float32), + tri_out.to(torch.float32), + atol=7e-2, + rtol=2e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_fp8_kv(Z, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + causal, + layout, + dtype=torch.float32): + current_platform.seed_everything(0) + + q, k_quantized, v_quantized, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + causal=causal, + layout=layout, + is_fp8=True, + fp8_kv=True) + + o = torch.empty_like(q) + + tri_out, _ = triton_attention_rocm(q, k_quantized, v_quantized, o, + input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd_fp8_kv(q, k_quantized, v_quantized) + + torch.testing.assert_close(ref_out, tri_out, atol=3e-2, rtol=8e-1) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 4, 128, 128, 65), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +@pytest.mark.parametrize('dtype', [torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): + current_platform.seed_everything(0) + q, k, v, input_metadata = input_helper(Z, + H, + H, + N_CTX_Q, + N_CTX_K, + D_HEAD, + dtype, + layout='bhsd', + causal=causal, + use_bias=use_bias) + o = torch.empty_like(q) + + # triton implementation + tri_out, _ = triton_attention_rocm(q, k, v, o, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.fwd(q, k, v) + + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(1, 48, 256, 64), + (4, 48, 512, 64), + (16, 48, 512, 64), + (64, 48, 128, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, H, H, N_CTX, N_CTX, D_HEAD, False, dtype, + input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=False) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +# NOTE: Uses thd layout, so also tests thd. 
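+# varlen_input_helper splits N_CTX across the Z sequences, drawing each
+# length from [1, N_CTX // Z] (or exactly N_CTX // Z when equal_seqlens=True),
+# so the packed q/k/v tensors it returns have at most N_CTX rows in total.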
+@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), + (4, 48, 12, 256, 64), + (4, 48, 4, 512, 64), + (4, 64, 16, 128, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, + HQ, + HK, + N_CTX, + D_HEAD, + causal, + dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, + D_HEAD, dtype) + + tri_out = torch.empty_like(q) + triton_attention_rocm(q, k, v, tri_out, input_metadata) + + ref_impl = ReferenceAttention(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, False, + dtype, input_metadata) + ref_out = ref_impl.varlen_fwd(q, k, v, is_mqa=True) + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 745818eb6cfff..e98b5254541b6 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -1,31 +1,237 @@ -#!/usr/bin/env python # SPDX-License-Identifier: Apache-2.0 """ Fused Attention =============== -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao -(https://tridao.me/publications/flash2/flash2.pdf) -Credits: OpenAI kernel team, AMD ML Frameworks Triton team +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf -Features supported: +Credits: +AMD Triton kernels team +OpenAI kernel team + +Currently only the forward kernel is supported, and contains these features: 1) Fwd with causal masking -2) Any sequence lengths without padding (currently fwd kernel only) -3) Support for different sequence lengths for q and k -4) Nested tensor API currently does not support dropout or bias. - -Not currently supported: - -1) Non power of two head dims +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias +7) FP8 support """ +from typing import Optional + import torch import triton import triton.language as tl -torch_dtype: tl.constexpr = torch.float16 +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +SUPPORTED_LAYOUTS = ['thd', 'bhsd', 'bshd'] + +default_eight_bit_dtype_triton = tl.float8e4b8 +default_eight_bit_dtype_torch = current_platform.fp8_dtype() +default_float8_info = torch.finfo(default_eight_bit_dtype_torch) + +FP8_MIN = triton.language.constexpr(default_float8_info.min) + +# According to https://github.com/vllm-project/vllm/blob/main +# /csrc/quantization/utils.cuh#L31, +# need to make the max for the uz datatype be 224.0 for accuracy reasons. +FP8_MAX = triton.language.constexpr( + default_float8_info.max if default_eight_bit_dtype_torch != + torch.float8_e4m3fnuz else 224.0) + + +class MetaData: + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + num_contexts = 0 + varlen = False + eight_bit = False + layout = None + return_encoded_softmax = False + eight_bit_dtype_triton = default_eight_bit_dtype_triton + eight_bit_dtype_torch = default_eight_bit_dtype_torch + output_dtype = None + + # Note about layouts: + # + # thd - [num_tokens, num_heads, head_size] + # bshd - [batch_size, seq_len, num_heads, head_size] + # bhsd - [batch_size, num_heads, seq_len, head_size] + # + # This is for each tensor, all tensors must have same layout. 
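+    # For example (hypothetical sizes): with batch_size=4, 48 query heads,
+    # 12 KV heads, seq_len=128 and head_size=64, 'bshd' means q is
+    # (4, 128, 48, 64) while k and v are (4, 128, 12, 64).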
+ # Q can have num_heads and seq_len differ from from K and V, + # however K and V must agree on this. + # + # Notes about varlen and bias: + # Only one or the other is implemented, meaning can't combine + # both varlen and bias right now. + # + # Note about quantization: + # Only 8-bit quantization supported (for now) and specifically fp8. + # Scales must be tensors. + # o_scale: This is 'output scaling', but comes from parameter called + # 'input_scale', this is applied to the output from the kernel. + # o_scale should be None if none of the other quantization parameters + # are used. + # + # NOTE: Object is in a tentatively good state after initialized, however, + # to verify, call check_args(q,k,v,o) where o is the output tensor. + def __init__( + self, + sm_scale=1.0, + layout=None, # layout can be 'bshd', 'bhsd', or 'thd' + output_dtype=None, + max_seqlens_q=0, + max_seqlens_k=0, + # varlen params + cu_seqlens_q=None, # only 'thd' layout supported for varlen + cu_seqlens_k=None, + # quant params + q_descale=None, + k_descale=None, + v_descale=None, + p_scale=None, + o_scale=None, + # bias params + bias=None, # varlen not implemented for bias + seqlen_q=None, + seqlen_k=None, + # alibi params + alibi_slopes=None, + alibi_batch=None, + alibi_nheads=None, + # causal + causal=None, + ): + self.sm_scale = sm_scale + self.output_dtype = output_dtype + self.max_seqlens_q = max_seqlens_q + self.max_seqlens_k = max_seqlens_k + self.layout = layout + if cu_seqlens_q is not None or cu_seqlens_k is not None: + assert cu_seqlens_q is not None and cu_seqlens_k is not None + assert layout is None or layout not in [ + 'bshd', 'bhsd' + ], "Varlen only implemented for thd layout" + self.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + quant_params = [q_descale, k_descale, v_descale, p_scale, o_scale] + if any(x is not None for x in quant_params): + p_descale = 1.0 / p_scale if p_scale is not None else None + self.set_eight_bit_params(q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale) + if bias is not None: + self.need_bias(bias, seqlen_q, seqlen_k) + if alibi_slopes is not None: + self.need_alibi(alibi_slopes, alibi_batch, alibi_nheads) + if causal is not None and causal: + self.need_causal() + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. 
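+        # e.g. (hypothetical) cu_seqlens_q = [0, 3, 7] encodes two sequences
+        # of lengths 3 and 4; num_contexts = len(cu_seqlens_q) - 1 and the
+        # loop below recovers max_seqlens_q = 4.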
+ assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max( + cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), + self.max_seqlens_q) + self.max_seqlens_k = max( + cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), + self.max_seqlens_k) + + def set_eight_bit_params(self, q_descale, k_descale, v_descale, p_scale, + p_descale, o_scale): + self.eight_bit = True + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + self.p_scale = p_scale + self.p_descale = p_descale + self.o_scale = o_scale + self.use_p_scale = (p_scale is not None) and ( + p_descale is not None) and (v_descale is not None) + self.eight_bit_kv = ((q_descale is None) and (k_descale is not None) + and (v_descale is not None)) + self.eight_bit_dtype_torch = default_eight_bit_dtype_torch + + def need_bias(self, bias, seqlen_q, seqlen_k): + assert bias is not None + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, self) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + if self.eight_bit: + if self.eight_bit_kv: + assert (v.dtype == k.dtype + and k.dtype == self.eight_bit_dtype_torch) + assert q.dtype != k.dtype + assert (self.v_descale is not None) and (self.k_descale + is not None) + else: + assert (q.dtype == k.dtype and q.dtype == v.dtype + and q.dtype == self.eight_bit_dtype_torch) + assert (self.q_descale + is not None) and (self.k_descale + is not None) and (self.v_descale + is not None) + if self.use_p_scale: + assert (self.p_scale is not None) and (self.p_descale + is not None) + else: + assert (q.dtype == k.dtype) and (q.dtype == v.dtype) + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen @triton.jit @@ -38,100 +244,155 @@ def max_fn(x, y): return tl.math.max(x, y) +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. 
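+# Passing None for either offset skips the boundary check in that dimension;
+# elements outside the given boundary are loaded as 0.0.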
@triton.jit -def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): - ms = tl.arange(0, m) - ns = tl.arange(0, n) - return philox_offset + ms[:, None] * stride + ns[None, :] - - -@triton.jit -def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, - stride).to(tl.uint32) - # TODO: use tl.randint for better performance - return tl.rand(philox_seed, rng_offsets) - - -@triton.jit -def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): - rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, - stride) - rng_keep = rng_output > dropout_p - return rng_keep - - -@triton.jit -def load_fn(block_ptr, first, second, pad): - if first and second: - tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) - elif first: - tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) - elif second: - tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) +def masked_load(ptrs, offset_first, offset_second, boundary_first, + boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) else: - tensor = tl.load(block_ptr) + tensor = tl.load(ptrs) return tensor +@triton.jit +def compute_alibi_block(alibi_slope, + seqlen_q, + seqlen_k, + offs_m, + offs_n, + transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to + # the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is + # masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that + # spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, + # offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = + # [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], [4], [ 4, 3, 2, 1, 0]] + # 5. 
-1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = (offs_m[:, None] + seqlen_k - seqlen_q - + offs_n[None, :]) + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, + device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, + device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - + k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze( + -1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def quant_fp8(x, scale): + x *= scale + x = tl.clamp(x, FP8_MIN, FP8_MAX) + return x + + @triton.jit def _attn_fwd_inner( acc, l_i, m_i, q, - K_block_ptr, - V_block_ptr, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, start_m, actual_seqlen_k, - dropout_p, + actual_seqlen_q, philox_seed, batch_philox_offset, - encoded_softmax_block_ptr, + encoded_sm_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, - bias_ptr, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - MASK_STEPS: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, - PADDED_HEAD: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + SHOULD_MASK_STEPS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_PADDED_HEAD: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, + QK_SCALE: tl.constexpr, + IS_EIGHT_BIT_GEMM: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): + # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. - k = load_fn( - K_block_ptr, - PADDED_HEAD, - MASK_STEPS and (n_extra_tokens != 0), - "zero", - ) - if PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + k_offs_n = start_n + tl.arange(0, + BLOCK_N) if SHOULD_MASK_STEPS else None + k_offs_k = None if not USE_PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = masked_load(k_ptrs, k_offs_k, k_offs_n, IS_ACTUAL_BLOCK_DMODEL, + actual_seqlen_k) + if SHOULD_PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # We start from end of seqlen_k so only the first iteration would need # to be checked for padding if it is not a multiple of block_n # TODO: This can be optimized to only be true for the padded block. - if MASK_STEPS: # noqa: SIM102 + if SHOULD_MASK_STEPS: # noqa: SIM102 # If this is the last block / iteration, we want to # mask if the sequence length is not a multiple of block size - # a solution is to always do BLOCK_M // BLOCK_N + 1 steps - # if not is_modulo_mn. last step might get wasted but that is okay. + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not + # is_modulo_mn. last step might get wasted but that is okay. 
# check if this masking works for that case. if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): boundary_m = tl.full([BLOCK_M], @@ -144,167 +405,276 @@ def _attn_fwd_inner( causal_boundary = start_n + offs_n_causal causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- - qk += tl.dot(q, k) - if bias_ptr is not None: - bias = load_fn(bias_ptr, False, MASK_STEPS - and (n_extra_tokens != 0), "zero") - # While bias is added after multiplying qk with sm_scale, our - # optimization to use 2^x instead of e^x results in an additional - # scale factor of log2(e) which we must also multiply the bias with. - qk += bias * 1.44269504089 + if IS_EIGHT_BIT_GEMM: + qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * + QK_SCALE) + else: + if IS_EIGHT_BIT_KV: + k = (k * k_descale).to(q.type.element_ty) + qk += (tl.dot(q, k) * QK_SCALE) + + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange( + 0, BLOCK_N) if SHOULD_MASK_STEPS else None + bias = masked_load(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, + actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an + # additional scale factor of log2(e) which we must also multiply + # the bias with. + qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, + actual_seqlen_k, + global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk = qk - m_ij[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout l_ij = tl.sum(p, 1) - if ENABLE_DROPOUT: - philox_offset = (batch_philox_offset + - start_m * BLOCK_M * actual_seqlen_k + start_n - - BLOCK_N) - keep = dropout_mask( - philox_seed, - philox_offset, - dropout_p, - BLOCK_M, - BLOCK_N, - actual_seqlen_k, - ) - if RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - tl.where(keep, p, - -p).to(encoded_softmax_block_ptr.type.element_ty), - ) - p = tl.where(keep, p, 0.0) - elif RETURN_ENCODED_SOFTMAX: - tl.store( - encoded_softmax_block_ptr, - p.to(encoded_softmax_block_ptr.type.element_ty), - ) + if SHOULD_RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- alpha = tl.math.exp2(m_i - m_ij) acc = acc * alpha[:, None] - if not PRE_LOAD_V: - v = load_fn( - V_block_ptr, - MASK_STEPS and (n_extra_tokens != 0), - PADDED_HEAD, - "zero", - ) + if not SHOULD_PRE_LOAD_V: + v = masked_load(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, + IS_ACTUAL_BLOCK_DMODEL) # -- update m_i and l_i l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij - acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, BLOCK_N)) + + if IS_EIGHT_BIT_GEMM: + if USE_P_SCALE: + p = quant_fp8(p, p_scale).to(QUANT_DTYPE) + acc += tl.dot(p, v) + else: + # v is in eight_bit but p is not, we want the gemm in p's type + acc += tl.dot(p, 
v.to(p.type.element_ty)) + else: + if IS_EIGHT_BIT_KV: + v = (v * v_descale).to(p.type.element_ty) + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N return acc, l_i, m_i +def get_cdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + # Fall-back config. 
+ triton.Config( + { + 'BLOCK_M': 16, + 'BLOCK_N': 16, + 'waves_per_eu': 1, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=2), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def get_general_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + triton.Config( + { + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'SHOULD_PRE_LOAD_V': False, + 'GRID_CU_MULTIP': 2 + }, + num_stages=1, + num_warps=4), + ], [ + 'IS_CAUSAL', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', + 'IS_ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK' + ] + + +def has_cdna_target(): + ROCM_CDNA_TARGETS = ["gfx940", "gfx941", "gfx942", "gfx90a", "gfx908"] + return triton.runtime.driver.active.get_current_target( + ).arch in ROCM_CDNA_TARGETS + + +def is_rocm_cdna(): + return current_platform.is_rocm() and has_cdna_target() + + +def get_autotune_configs(): + if is_rocm_cdna(): + return get_cdna_autotune_configs() + elif current_platform.is_rocm(): + return get_rdna_autotune_configs() + else: + return get_general_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + + @triton.autotune( - configs=[ - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - triton.Config( - { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - triton.Config( - { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=8, - ), - # TODO: This config fails with head_size not pow2 with data mismatches. 
- # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, - # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), - triton.Config( - { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, - }, - num_stages=1, - num_warps=4, - ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, ) @triton.jit def attn_fwd( @@ -312,38 +682,53 @@ def attn_fwd( K, V, bias, - sm_scale, + SM_SCALE: tl.constexpr, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, + stride_az: tl.int64, + stride_ah: tl.int64, + q_descale_ptr, + k_descale_ptr, + p_scale_ptr, + p_descale_ptr, + o_descale_ptr, + v_descale_ptr, + q_descale_has_singleton: tl.constexpr, + k_descale_has_singleton: tl.constexpr, + p_descale_has_singleton: tl.constexpr, + v_descale_has_singleton: tl.constexpr, cu_seqlens_q, cu_seqlens_k, - dropout_p, philox_seed, + NUM_CU: tl.constexpr, + GRID_CU_MULTIP: tl.constexpr, + B: tl.constexpr, philox_offset_base, encoded_softmax, + alibi_slopes, HQ: tl.constexpr, HK: tl.constexpr, - ACTUAL_BLOCK_DMODEL: tl.constexpr, + IS_ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, @@ -351,24 +736,39 @@ def attn_fwd( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, - PRE_LOAD_V: tl.constexpr, - BIAS_TYPE: tl.constexpr, - ENABLE_DROPOUT: tl.constexpr, - RETURN_ENCODED_SOFTMAX: tl.constexpr, + SHOULD_PRE_LOAD_V: tl.constexpr, + USE_BIAS: tl.constexpr, + SHOULD_RETURN_ENCODED_SOFTMAX: tl.constexpr, + USE_ALIBI: tl.constexpr, + IS_EIGHT_BIT: tl.constexpr, + USE_P_SCALE: tl.constexpr, + IS_EIGHT_BIT_KV: tl.constexpr, + QUANT_DTYPE: tl.constexpr = default_eight_bit_dtype_triton, ): - start_m = tl.program_id(0) - off_h_q = tl.program_id(1) - off_z = tl.program_id(2) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) + + if o_descale_ptr is not None: + o_descale = tl.load(o_descale_ptr) + + start_m: tl.int64 = tl.program_id(0) + off_h_q: tl.int64 = tl.program_id(1) + off_z: tl.int64 = tl.program_id(2) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M).to(tl.int64) + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_d = tl.arange(0, BLOCK_DMODEL).to(tl.int64) + + # as we can't have return statements inside while loop in Triton + continue_condition = True + if VARLEN: cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start - # We have a one-size-fits-all grid in id(0). Some seqlens might be too - # small for all start_m so for those we return early. + # We have a one-size-fits-all grid in id(0). Some seqlens might be + # too small for all start_m so for those we return early. 
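+        # e.g. with BLOCK_M = 128 and a grid sized from MAX_SEQLENS_Q = 512,
+        # a sequence of only 200 tokens needs just the first two start_m
+        # programs; the later ones skip all work via continue_condition.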
if start_m * BLOCK_M > seqlen_q: - return + continue_condition = False + # return cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start @@ -378,444 +778,598 @@ def attn_fwd( seqlen_q = MAX_SEQLENS_Q seqlen_k = MAX_SEQLENS_K - # Now we compute whether we need to exit early due to causal masking. - # This is because for seqlen_q > seqlen_k, M rows of the attn scores - # are completely masked, resulting in 0s written to the output, and - # inf written to LSE. We don't need to do any GEMMs in this case. - # This block of code determines what N is, and if this WG is operating - # on those M rows. - n_blocks = cdiv_fn(seqlen_k, BLOCK_N) - if IS_CAUSAL: - # If seqlen_q == seqlen_k, the attn scores are a square matrix. - # If seqlen_q != seqlen_k, attn scores are rectangular which means - # the causal mask boundary is bottom right aligned, and ends at either - # the top edge (seqlen_q < seqlen_k) or left edge. - # This captures the decrease in n_blocks if we have a rectangular attn - # matrix - n_blocks_seqlen = cdiv_fn( - (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) - # This is what adjusts the block_max for the current WG, only - # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks - n_blocks = min(n_blocks, n_blocks_seqlen) - # If we have no blocks after adjusting for seqlen deltas, this WG is - # part of the blocks that are all 0. We exit early. - if n_blocks <= 0: - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) - # We still need to write 0s to the result - # tl.store(O_block_ptr, - # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q - # + offs_m - # We store inf to LSE, not -inf because in the bwd pass, - # we subtract this - # from qk which makes it -inf, such that exp(qk - inf) = 0 - # for these masked blocks. - # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - # tl.store(l_ptrs, l) - # TODO: Should dropout and return encoded softmax be handled here? - return + if continue_condition: + # Now we compute whether we need to exit early due to causal + # masking. This is because for seqlen_q > seqlen_k, M rows of the + # attn scores are completely masked, resulting in 0s written to the + # output, and inf written to LSE. We don't need to do any GEMMs in + # this case. This block of code determines what N is, and if this + # WG is operating on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which + # means the causal mask boundary is bottom right aligned, and + # ends at either the top edge (seqlen_q < seqlen_k) or left + # edge. This captures the decrease in n_blocks if we have a + # rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. 
Otherwise we want to always iterate through all + # n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this + # WG is part of the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to( + [BLOCK_M, BLOCK_DMODEL]) + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as + # that is statically known. + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this from qk which makes it -inf, such that + # exp(qk - inf) = 0 for these masked blocks. + l_value = tl.full([BLOCK_M], + value=float("inf"), + dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l_value, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be + # handled here too? + continue_condition = False + # return - # If MQA / GQA, set the K and V head offsets appropriately. - GROUP_SIZE: tl.constexpr = HQ // HK - off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + if continue_condition: + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + USE_PADDED_HEAD: tl.constexpr = (IS_ACTUAL_BLOCK_DMODEL + != BLOCK_DMODEL) - n_extra_tokens = 0 - if seqlen_k < BLOCK_N: - n_extra_tokens = BLOCK_N - seqlen_k - elif seqlen_k % BLOCK_N: - n_extra_tokens = seqlen_k % BLOCK_N - padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + # Compute pointers for all the tensors used in this kernel. + q_offset = (Q + off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + q_ptrs = (q_offset + offs_m[:, None] * stride_qm + + offs_d[None, :] * stride_qk) + k_offset = (K + off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + k_ptrs = (k_offset + offs_d[:, None] * stride_kk + + offs_n[None, :] * stride_kn) + v_offset = (V + off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + v_ptrs = (v_offset + offs_n[:, None] * stride_vk + + offs_d[None, :] * stride_vn) + # Compute pointers for all scale tensors used in this kernel. - # Compute pointers for all the tensors used in this kernel. 
- q_offset = (off_z * stride_qz + off_h_q * stride_qh + - cu_seqlens_q_start * stride_qm) - Q_block_ptr = tl.make_block_ptr( - base=Q + q_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - k_offset = (off_z * stride_kz + off_h_k * stride_kh + - cu_seqlens_k_start * stride_kn) - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1), - ) - v_offset = (off_z * stride_vz + off_h_k * stride_vh + - cu_seqlens_k_start * stride_vk) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0), - ) - if BIAS_TYPE != 0: - bias_ptr = tl.make_block_ptr( - base=bias + off_h_q * stride_bh, - shape=(seqlen_q, seqlen_k), - strides=(stride_bm, stride_bn), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - bias_ptr = None - if ENABLE_DROPOUT: - batch_philox_offset = philox_offset_base \ - + (off_z * HQ + off_h_q) \ - * seqlen_q * seqlen_k - else: - batch_philox_offset = 0 - # We can ask to return the dropout mask without actually doing any dropout. - # In this case, we return an invalid pointer so indicate the mask is not i - # valid. - # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.make_block_ptr( - base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, - shape=(seqlen_q, seqlen_k), - strides=(seqlen_k, 1), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_N), - order=(1, 0), - ) - else: - encoded_softmax_block_ptr = 0 - # initialize pointer to m and l - m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) - l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use 2^x in the loop as we do not - # have native e^x support in HW. - qk_scale = sm_scale * 1.44269504089 - # Q is loaded once at the beginning and shared by all N blocks. - q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + IS_EIGHT_BIT_GEMM: tl.constexpr = IS_EIGHT_BIT & ( + not IS_EIGHT_BIT_KV) + if IS_EIGHT_BIT: + if k_descale_has_singleton: + k_descale_ptrs = k_descale_ptr + else: + k_descale_ptrs = k_descale_ptr + off_h_k - # Here we compute how many full and masked blocks we have. - padded_block_k = n_extra_tokens != 0 - is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) - if IS_CAUSAL: - # There are always at least BLOCK_M // BLOCK_N masked blocks. - # Additionally there might be one more due to dissimilar seqlens. - masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) - else: - # Padding on Q does not need to be masked in the FA loop. - masked_blocks = padded_block_k - # if IS_CAUSAL, not is_modulo_mn does not always result in an additional - # block. In this case we might exceed n_blocks so pick the min. - masked_blocks = min(masked_blocks, n_blocks) - n_full_blocks = n_blocks - masked_blocks - block_min = 0 - block_max = n_blocks * BLOCK_N - # Compute for full blocks. Here we set causal to false regardless of its - # value because there is no masking. Similarly we do not need padding. 
- if n_full_blocks > 0: - block_max = (n_blocks - masked_blocks) * BLOCK_N - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ - block_min, - block_max, - 0, - 0, - 0, - bias_ptr, - # IS_CAUSAL, .... - False, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - False, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - block_min = block_max - block_max = n_blocks * BLOCK_N + if v_descale_has_singleton: + v_descale_ptrs = v_descale_ptr + else: + v_descale_ptrs = v_descale_ptr + off_h_k - tl.debug_barrier() - # Remaining blocks, if any, are full / not masked. - if masked_blocks > 0: - offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 - K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) - V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) - if bias_ptr is not None: - bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) - if RETURN_ENCODED_SOFTMAX: - encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, - (0, n_full_blocks)) - acc, l_i, m_i = _attn_fwd_inner( - acc, - l_i, - m_i, - q, - K_block_ptr, - V_block_ptr, - start_m, - seqlen_k, - dropout_p, - philox_seed, - batch_philox_offset, - encoded_softmax_block_ptr, - block_min, - block_max, - offs_n_causal, - masked_blocks, - n_extra_tokens, - bias_ptr, - IS_CAUSAL, - BLOCK_M, - BLOCK_DMODEL, - BLOCK_N, - offs_m, - offs_n, - # _, MASK_STEPS, ... - PRE_LOAD_V, - True, - ENABLE_DROPOUT, - RETURN_ENCODED_SOFTMAX, - padded_head, - ) - # epilogue - acc = acc / l_i[:, None] - if ENABLE_DROPOUT: - acc = acc / (1 - dropout_p) - # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, - # then we have one block with a row of all NaNs which come from computing - # softmax over a row of all -infs (-inf - inf = NaN). We check for that here - # and store 0s where there are NaNs as these rows should've been zeroed out. - end_m_idx = (start_m + 1) * BLOCK_M - start_m_idx = start_m * BLOCK_M - causal_start_idx = seqlen_q - seqlen_k - acc = acc.to(Out.type.element_ty) - if IS_CAUSAL: # noqa: SIM102 - if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: - out_mask_boundary = tl.full((BLOCK_DMODEL, ), - causal_start_idx, - dtype=tl.int32) - mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) - out_ptrs_mask = (mask_m_offsets[:, None] - >= out_mask_boundary[None, :]) - z = 0.0 - acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) - # write back LSE - # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m - # If seqlen_q not multiple of BLOCK_M, we need to mask out the last - # few rows. This is only true for the last M block. For others, - # overflow_size will be -ve - # overflow_size = end_m_idx - seqlen_q - # if overflow_size > 0: - # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # # This is a > check because mask being 0 blocks the store. 
- # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - # else: - # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + if not IS_EIGHT_BIT_KV: + if q_descale_has_singleton: + q_descale_ptrs = q_descale_ptr + else: + q_descale_ptrs = q_descale_ptr + off_h_q + if USE_P_SCALE: + if p_descale_has_singleton: + p_scale_ptrs = p_scale_ptr + p_descale_ptrs = p_descale_ptr + else: + p_scale_ptrs = p_scale_ptr + off_h_q + p_descale_ptrs = p_descale_ptr + off_h_q - # write back O - o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + - off_h_q * stride_oh) - O_block_ptr = tl.make_block_ptr( - base=Out + o_offset, - shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0), - ) - # Need boundary check on this to make sure the padding from the - # Q and KV tensors in both dims are not part of what we store back. - # TODO: Do the boundary check optionally. - tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + if USE_BIAS: + bias_offset = off_h_q * stride_bh + bias_ptrs = (bias + bias_offset + offs_m[:, None] * stride_bm + + offs_n[None, :] * stride_bn) + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + batch_philox_offset = 0 + # We can ask to return the dropout mask without doing any + # dropout. In this case, we return an invalid pointer so + # indicate the mask is not valid. + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_base = (encoded_softmax + + off_h_q * seqlen_q * seqlen_k) + encoded_sm_ptrs = (encoded_sm_base + + offs_m[:, None] * seqlen_k + + offs_n[None, :]) + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do + # not have native e^x support in HW. + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if USE_PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + if IS_EIGHT_BIT: + k_descale = tl.load(k_descale_ptrs) + v_descale = tl.load(v_descale_ptrs) + q_descale = None if IS_EIGHT_BIT_KV else tl.load( + q_descale_ptrs) + if USE_P_SCALE: + p_scale = tl.load(p_scale_ptrs) + p_descale = tl.load(p_descale_ptrs) + else: + p_scale = None + p_descale = None + else: + q_descale = None + k_descale = None + v_descale = None + p_scale = None + p_descale = None + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked + # blocks. Additionally there might be one more due to + # dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an + # additional block. In this case we might exceed n_blocks so + # pick the min. 
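+            # e.g. with IS_CAUSAL, BLOCK_M = 128 and BLOCK_N = 32 there are
+            # at least 4 masked blocks (one more when seqlen_q % BLOCK_M != 0
+            # or the last K block is padded), and the min() below keeps this
+            # from exceeding n_blocks.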
+ masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false + # regardless of its actual value because there is no masking. + # Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... + SHOULD_PRE_LOAD_V, + False, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if SHOULD_RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + k_ptrs, + v_ptrs, + bias_ptrs, + stride_kn, + stride_vk, + stride_bn, + start_m, + seqlen_k, + seqlen_q, + philox_seed, + batch_philox_offset, + encoded_sm_ptrs, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + alibi_slope, + q_descale, + k_descale, + v_descale, + p_scale, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, SHOULD_MASK_STEPS, ... + SHOULD_PRE_LOAD_V, + True, + SHOULD_RETURN_ENCODED_SOFTMAX, + USE_PADDED_HEAD, + IS_ACTUAL_BLOCK_DMODEL, + QK_SCALE, + IS_EIGHT_BIT_GEMM, + USE_P_SCALE, + IS_EIGHT_BIT_KV, + QUANT_DTYPE) + + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: + if USE_P_SCALE: + acc *= p_descale + acc *= v_descale + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc + # which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + # If seqlen_q > seqlen_k but the delta is not a multiple of + # BLOCK_M, then we have one block with a row of all NaNs which + # come from computing softmax over a row of all + # -infs (-inf - inf = NaN). We check for that here and store 0s + # where there are NaNs as these rows should've been zeroed out. 
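The full/masked split above lets most of the K/V loop run without boundary or causal masking; only the trailing masked_blocks iterations pay for masking. A simplified host-side sketch of the same counting (illustrative; in the kernel n_blocks and n_extra_tokens are derived per query block, which this sketch glosses over):

import math

def split_blocks(seqlen_q, seqlen_k, block_m, block_n, causal):
    n_blocks = math.ceil(seqlen_k / block_n)
    n_extra_tokens = seqlen_k % block_n        # simplified stand-in
    padded_block_k = n_extra_tokens != 0
    is_modulo_mn = not padded_block_k and (seqlen_q % block_m == 0)
    if causal:
        # At least BLOCK_M // BLOCK_N masked blocks, plus possibly one
        # more when the sequence lengths do not line up.
        masked_blocks = block_m // block_n + (not is_modulo_mn)
    else:
        masked_blocks = int(padded_block_k)
    masked_blocks = min(masked_blocks, n_blocks)
    return n_blocks - masked_blocks, masked_blocks   # (full, masked)

print(split_blocks(128, 128, 128, 64, causal=True))  # (0, 2)
print(split_blocks(128, 320, 128, 64, causal=True))  # (3, 2)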
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + if IS_EIGHT_BIT and not IS_EIGHT_BIT_KV: # noqa: SIM102 + if o_descale_ptr is not None: + acc = quant_fp8(acc, o_descale) + + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if (causal_start_idx > start_m_idx + and causal_start_idx < end_m_idx): + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] + >= out_mask_boundary[None, :]) + z = tl.zeros((1, ), tl.float32) + acc = tl.where(out_ptrs_mask, acc, + z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = (L + off_z * HQ * MAX_SEQLENS_Q + + off_h_q * MAX_SEQLENS_Q + offs_m) + # If seqlen_q not multiple of BLOCK_M, we need to mask out the + # last few rows. This is only true for the last M block. + # For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), + BLOCK_M - overflow_size, + dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (Out + off_z * stride_oz + off_h_q * stride_oh + + cu_seqlens_q_start * stride_om) + o_ptrs = (o_offset + offs_m[:, None] * stride_om + + offs_d[None, :] * stride_on) + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if USE_PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] + < IS_ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) -def check_args( - q, - k, - v, - o, - varlen=True, - max_seqlens=None, - cu_seqlens_q=None, - cu_seqlens_k=None, -): - assert q.dim() == k.dim() and q.dim() == v.dim() - if varlen: - assert q.dim() == 3 - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - assert cu_seqlens_q is not None - assert cu_seqlens_k is not None - assert len(cu_seqlens_q) == len(cu_seqlens_k) - else: - assert q.dim() == 4 - batch, nheads_q, seqlen_q, head_size = q.shape - _, nheads_k, seqlen_k, _ = k.shape - assert max_seqlens > 0 - assert k.shape == v.shape - assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] - # TODO: Change assert if we support qkl f8 and v f16 - assert q.dtype == k.dtype and q.dtype == v.dtype - assert head_size <= 256 - assert o.shape == q.shape - assert (nheads_q % nheads_k) == 0 +def get_shape_from_layout(q, k, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." + + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + return batch, nheads_q, nheads_k, head_size + + +def get_strides_from_layout(q, k, v, o, metadata): + assert metadata.layout in SUPPORTED_LAYOUTS, "Got unsupported layout." 
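The causal zeroing above guards against rows whose scores are entirely -inf: softmax over such a row produces NaNs (from -inf - (-inf)), and those rows must be stored as zeros. A minimal PyTorch reproduction of the effect being guarded against (illustrative only):

import torch

scores = torch.full((2, 4), float("-inf"))
scores[1, :2] = torch.randn(2)        # row 0 stays fully masked

p = torch.softmax(scores, dim=-1)
print(torch.isnan(p[0]).all())        # True: the all -inf row becomes NaN

p = torch.nan_to_num(p, nan=0.0)      # what the kernel's masking achieves
print(p[0])                           # zeros, the value the row should have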
+ + STRIDE_PERMUTATIONS = { + 'thd': (None, 1, 0, 2), + 'bhsd': (0, 1, 2, 3), + 'bshd': (0, 2, 1, 3), + } + + perm = STRIDE_PERMUTATIONS[metadata.layout] + stride = lambda x, p: (0 if p is None else x.stride(p)) + strides = lambda x: (stride(x, p) for p in perm) + + return tuple(strides(x) for x in [q, k, v, o]) class _attention(torch.autograd.Function): @staticmethod - def forward( - ctx, - q, - k, - v, - o, - cu_seqlens_q, - cu_seqlens_k, - max_seqlens_q, - max_seqlens_k, - causal=False, - sm_scale=1.0, - bias=None, - ): - if o is None: - o = torch.empty_like(q, dtype=v.dtype) + def forward(ctx, q, k, v, o, metadata: MetaData): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) - check_args( - q, - k, - v, - o, - varlen=True, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - ) - if True: # varlen - total_q, nheads_q, head_size = q.shape - total_k, nheads_k, _ = k.shape - batch = len(cu_seqlens_q) - 1 - q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) - k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) - v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) - o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) - else: - batch, seqlen_q, nheads_q, head_size = q.shape - _, seqlen_k, nheads_k, _ = k.shape - q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) - k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) - v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) - o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + if o is None: + if metadata.eight_bit: + o = torch.empty_like( + q, + dtype=metadata.output_dtype if metadata.output_dtype + is not None else metadata.eight_bit_dtype_torch) + else: + o = torch.empty_like(q, dtype=q.dtype) + + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout( + q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout( + q, k, v, o, metadata) # Get closest power of 2 over or equal to 32. - unpadded_head_dims = {32, 64, 128, 256} - if head_size not in unpadded_head_dims: - padded_d_model = None - for i in unpadded_head_dims: - if i > head_size: - padded_d_model = i - break - assert padded_d_model is not None + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + # encoded_softmax is used to validate dropout behavior vs the + # PyTorch SDPA math backend reference. We zero this out to give a + # consistent starting point and then populate it with the output of + # softmax with the sign bit set according to the dropout mask. + # The resulting return allows this mask to be fed into the reference + # implementation for testing only. This return holds no useful output + # aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2], k.shape[2]), + device=q.device, + dtype=torch.float32) else: - padded_d_model = head_size + encoded_softmax = None - grid = lambda META: ( - triton.cdiv(max_seqlens_q, META["BLOCK_M"]), - nheads_q, - batch, - ) - - encoded_softmax = None + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), + device=q.device, + dtype=torch.float32) # Seed the RNG so we get reproducible results for testing. 
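The STRIDE_PERMUTATIONS table above maps each supported layout to the stride order the kernel expects, (batch, head, seq, dim), with a zero batch stride for the packed 'thd' varlen layout. A small host-side sketch of the same lookup (illustrative; strides_for is a hypothetical stand-in for get_strides_from_layout applied to one tensor):

import torch

STRIDE_PERMUTATIONS = {
    'thd': (None, 1, 0, 2),
    'bhsd': (0, 1, 2, 3),
    'bshd': (0, 2, 1, 3),
}

def strides_for(t: torch.Tensor, layout: str):
    perm = STRIDE_PERMUTATIONS[layout]
    return tuple(0 if p is None else t.stride(p) for p in perm)

q_bshd = torch.empty(2, 1024, 8, 64)   # (batch, seq, heads, dim)
q_thd = torch.empty(2048, 8, 64)       # packed varlen (tokens, heads, dim)
print(strides_for(q_bshd, 'bshd'))     # (524288, 64, 512, 1)
print(strides_for(q_thd, 'thd'))       # (0, 64, 512, 1) - no batch stride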
philox_seed = 0x1BF52 philox_offset = 0x1D4B42 - if bias is not None: - bias_strides = ( - bias.stride(0), - bias.stride(1), - bias.stride(2), - bias.stride(3), - ) + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), + metadata.bias.stride(2), metadata.bias.stride(3)) else: bias_strides = (0, 0, 0, 0) + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), + metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + if metadata.eight_bit: + q_descale, k_descale, p_scale, p_descale, v_descale, o_scale = ( + metadata.q_descale, metadata.k_descale, metadata.p_scale, + metadata.p_descale, metadata.v_descale, metadata.o_scale) + o_descale = 1.0 / o_scale if o_scale is not None else None + else: + q_descale = k_descale = p_scale = None + p_descale = v_descale = o_descale = None + + # number of compute units available + NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count + + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META[ + 'BLOCK_M']), nheads_q, batch) + attn_fwd[grid]( q, k, v, - bias, - sm_scale, - None, + metadata.bias, + metadata.sm_scale, + M, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, - cu_seqlens_q, - cu_seqlens_k, - dropout_p=0.0, + *alibi_strides, + q_descale, + k_descale, + p_scale, + p_descale, + o_descale, + v_descale, + q_descale.numel() == 1 if q_descale is not None else False, + k_descale.numel() == 1 if k_descale is not None else False, + p_descale.numel() == 1 if p_descale is not None else False, + v_descale.numel() == 1 if v_descale is not None else False, + metadata.cu_seqlens_q, + metadata.cu_seqlens_k, philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, - ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, - IS_CAUSAL=causal, - VARLEN=True, + IS_ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, + MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, + VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, - BIAS_TYPE=0 if bias is None else 1, - ENABLE_DROPOUT=False, - RETURN_ENCODED_SOFTMAX=False, - ) + USE_BIAS=metadata.bias is not None, + USE_ALIBI=metadata.alibi_slopes is not None, + SHOULD_RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, + IS_EIGHT_BIT=metadata.eight_bit, + USE_P_SCALE=metadata.eight_bit and metadata.use_p_scale, + IS_EIGHT_BIT_KV=metadata.eight_bit and metadata.eight_bit_kv, + NUM_CU=NUM_CU, + B=batch, + QUANT_DTYPE=metadata.eight_bit_dtype_triton) ctx.grid = grid - ctx.sm_scale = sm_scale + ctx.sm_scale = metadata.sm_scale ctx.BLOCK_DMODEL = head_size - ctx.causal = causal - ctx.dropout_p = 0.0 + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes ctx.philox_seed = philox_seed ctx.philox_offset = philox_offset ctx.encoded_softmax = encoded_softmax - ctx.return_encoded_softmax = False + ctx.return_encoded_softmax = metadata.return_encoded_softmax return o, encoded_softmax -triton_attention = _attention.apply +triton_attention_rocm = _attention.apply + + +def scale_fp8(t, scale=None): + t_scaled, scale_out = ops.scaled_fp8_quant(t.reshape(-1, t.shape[-1]), + scale) + return t_scaled.reshape(t.shape), scale_out + + +def maybe_quantize_fp8(t, scale): + eight_bit_dtype = current_platform.fp8_dtype() + if t.dtype != eight_bit_dtype: + t, _ = scale_fp8(t, scale) + return t + + +def 
check_and_maybe_quantize_qkv(q, k, v, fp8_scales):
+    (q_scale, k_scale, v_scale, p_scale) = fp8_scales
+
+    q = maybe_quantize_fp8(q, q_scale)
+    k = maybe_quantize_fp8(k, k_scale)
+    v = maybe_quantize_fp8(v, v_scale)
+
+    return q, k, v
+
+
+# query - [num_tokens, num_heads, head_size]
+# key - [num_tokens, num_kv_heads, head_size]
+# value - [num_tokens, num_kv_heads, head_size]
+# output - [num_tokens, num_heads, head_size]
+def triton_attention(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    o: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlens_q: int,
+    max_seqlens_k: int,
+    causal: bool = False,
+    sm_scale: float = 1.0,
+    bias: Optional[torch.Tensor] = None,
+    fp8_scales: Optional[tuple[float, ...]] = None,
+    input_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if fp8_scales is not None:
+        q_descale, k_descale, v_descale, p_scale = fp8_scales
+    else:
+        q_descale = k_descale = v_descale = p_scale = None
+
+    attn_metadata = MetaData(sm_scale=sm_scale,
+                             max_seqlens_q=max_seqlens_q,
+                             max_seqlens_k=max_seqlens_k,
+                             causal=causal,
+                             bias=bias,
+                             output_dtype=q.dtype,
+                             cu_seqlens_q=cu_seqlens_q,
+                             cu_seqlens_k=cu_seqlens_k,
+                             q_descale=q_descale,
+                             k_descale=k_descale,
+                             v_descale=v_descale,
+                             p_scale=p_scale,
+                             o_scale=input_scale)
+
+    if fp8_scales is not None:
+        q, k, v = check_and_maybe_quantize_qkv(q, k, v, fp8_scales)
+
+    return triton_attention_rocm(q, k, v, o, attn_metadata)
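For reference, the launch grid built in forward above assigns one program per (query block, query head, batch) triple. A tiny sketch of the resulting grid size (illustrative; triton.cdiv is just ceiling division):

import math

def launch_grid(max_seqlens_q: int, nheads_q: int, batch: int, block_m: int):
    # Mirrors: grid = lambda META: (triton.cdiv(max_seqlens_q,
    #     META['BLOCK_M']), nheads_q, batch)
    return (math.ceil(max_seqlens_q / block_m), nheads_q, batch)

print(launch_grid(4096, 32, 2, block_m=128))   # (32, 32, 2) programs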
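A hypothetical end-to-end call of the new triton_attention wrapper on packed varlen inputs (requires a ROCm GPU; the shapes, sequence lengths, and the (output, encoded_softmax) unpacking are assumptions based on the code above, not a documented contract):

import torch
from vllm.attention.ops.triton_flash_attention import triton_attention

# Two sequences of lengths 5 and 7 packed into one varlen batch.
seq_lens = [5, 7]
cu_seqlens = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
num_heads, num_kv_heads, head_size = 8, 8, 64
total_tokens = sum(seq_lens)

q = torch.randn(total_tokens, num_heads, head_size,
                dtype=torch.float16, device="cuda")
k = torch.randn(total_tokens, num_kv_heads, head_size,
                dtype=torch.float16, device="cuda")
v = torch.randn_like(k)
o = torch.empty_like(q)

out, _ = triton_attention(q, k, v, o, cu_seqlens, cu_seqlens,
                          max_seqlens_q=max(seq_lens),
                          max_seqlens_k=max(seq_lens),
                          causal=True,
                          sm_scale=head_size**-0.5)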