# 融合注意力(Fused Attention)

本节实现了一个基于 **Triton** 的 **Flash Attention v2 风格的融合注意力前向传播内核**,适用于昇腾(Ascend)NPU 平台。该实现支持:

- **因果(causal)与非因果注意力**
- **分块计算(tiling)以处理长序列**
- **数值稳定性优化(max-shifted softmax)**

整体结构包含两个核心 Triton 内核:

1. `_attn_fwd_inner`:执行单个 query block 与 key/value blocks 的 attention 计算(分阶段处理 causal mask)
2. `_attn_fwd`:调度所有 query blocks,并管理 block 指针、accumulator 和归一化

并通过 PyTorch `autograd.Function` 封装为可调用的 `attention` 函数,与 `torch_npu.npu_fusion_attention` 进行精度对齐验证。

```Python
import pytest
import torch
import torch_npu
import triton
import triton.language as tl

DEVICE = "npu"


# Processes one query block against a range of key/value blocks and returns the
# updated (accumulator, softmax denominator, running row maximum).
#
# `acc_ptr` has two meanings, selected at compile time by HEAD_DIM:
#   * HEAD_DIM < 256:  a [BLOCK_M, HEAD_DIM] float32 tensor holding the partial output.
#   * HEAD_DIM >= 256: a pointer into the `acc` scratch tensor; the partial output is
#                      loaded, updated in four row slices and stored back.
@triton.jit
def _attn_fwd_inner(acc_ptr, l_i, m_i, q,  # Accumulator, running denominator, running row max, query block
                    K_block_ptr, V_block_ptr,  # Key and value block pointers (positioned at row 0)
                    start_m, qk_scale,  # Index of the current query block, qk scale factor
                    BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  # Block size constants
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  # Stage flag, row / column offsets
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):  # Context length, whether V is stored as FP8
    # Select the key/value range [lo, hi) for this stage, expressed in token positions.
    # Causal attention only lets a query attend to its own and earlier positions, so a
    # causal pass is split into an unmasked part (stage 1) and a masked part (stage 2).
    if STAGE == 1:
        # Stage 1 (off-band): all keys before the current query block; no mask needed.
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        # Stage 2 (on-band): the keys aligned with the current query block.
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)  # Alignment hint for the start position
    else:
        # Non-causal: process the entire context without masking.
        lo, hi = 0, N_CTX

    # Move the K and V block pointers to row `lo`. Both block pointers are created
    # with shape (N_CTX, HEAD_DIM), so the shift is along the first dimension.
    K_block_ptr = tl.advance(K_block_ptr, (lo, 0))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # Element offsets of one [BLOCK_M, HEAD_DIM] tile; only used when HEAD_DIM >= 256.
    row = tl.arange(0, BLOCK_M)[:, None]
    col_head_dim = tl.arange(0, HEAD_DIM)[None, :]
    block2d_acc = row * HEAD_DIM + col_head_dim

    # Iterate over the K/V blocks of this stage, BLOCK_N rows at a time.
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)  # Alignment hint for the block start

        # -- Compute qk --
        k = tl.load(K_block_ptr)
        trans_k = tl.trans(k)  # K is loaded as [BLOCK_N, HEAD_DIM]; transpose for q @ k^T
        qk = tl.dot(q, trans_k)

        if STAGE == 2:
            # Keep only keys at or before the query position (lower triangle);
            # masked entries get a large negative bias (-1.0e6, not a true -inf).
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))  # New running row maximum
            qk -= m_ij[:, None]  # Subtract the max for softmax stability
        else:
            qk = qk * qk_scale
            m_ij = tl.maximum(m_i, tl.max(qk, 1))  # New running row maximum
            qk = qk - m_ij[:, None]  # Subtract the max for softmax stability

        # Unnormalized softmax weights.
        p = tl.math.exp(qk)

        # Cast the weights to the dtype used for the second matmul.
        if fp8_v:
            p_cast = p.to(tl.float8e5)
        else:
            p_cast = p.to(k.dtype)

        v = tl.load(V_block_ptr)  # Load the corresponding V block
        pv = tl.dot(p_cast, v)
        l_ij = tl.sum(p, 1)  # Row sums of this block's weights

        # -- Update m_i and l_i --
        alpha = tl.math.exp(m_i - m_ij)  # Rescale factor from the old to the new max
        l_i = l_i * alpha + l_ij

        # -- Update output accumulator --
        if HEAD_DIM < 256:
            acc_ptr = acc_ptr * alpha[:, None]
            acc_ptr = tl.dot(p_cast, v, acc_ptr)
        else:
            # 1. Load the current accumulator tile
            acc = tl.load(acc_ptr + block2d_acc)
            # 2. Update in slices (split by 1/4 of BLOCK_M to avoid ub overflow)
            for i in range(4):
                # First row of the current slice
                offset = i * (BLOCK_M // 4)
                # Extract slice data
                acc_i = tl.extract_slice(acc, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
                alpha_i = tl.extract_slice(alpha, [offset], [BLOCK_M // 4], [1])
                pv_i = tl.extract_slice(pv, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
                # Incrementally update slice: acc = acc * alpha + pv
                acc_i = acc_i * alpha_i[:, None] + pv_i
                # Write updated slice back to accumulator
                acc = tl.insert_slice(acc, acc_i, (offset, 0), (BLOCK_M // 4, HEAD_DIM), (1, 1))
            # 3. Store the updated accumulator tile
            tl.store(acc_ptr + block2d_acc, acc)

        m_i = m_ij  # The new maximum becomes the running maximum

        # Advance V and K block pointers to the next BLOCK_N rows
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))

    # Return accumulated output acc_ptr, softmax denominator l_i, and max value m_i
    return acc_ptr, l_i, m_i


# Forward kernel. The work is split into Z * H * (N_CTX // BLOCK_M) tasks, one per
# (batch, head, query block). Each program walks the task list with a stride of
# NUM_CORES, so NUM_CORES must equal the launch grid size.
#
# STAGE is 3 for causal attention and 1 for non-causal attention (set by the launcher).
@triton.jit
def _attn_fwd(Q, K, V, M, Out, acc, sm_scale,
              stride_qz: tl.constexpr, stride_qh: tl.constexpr, stride_qm: tl.constexpr, stride_qk: tl.constexpr,
              stride_kz: tl.constexpr, stride_kh: tl.constexpr, stride_kn: tl.constexpr, stride_kk: tl.constexpr,
              stride_vz: tl.constexpr, stride_vh: tl.constexpr, stride_vn: tl.constexpr, stride_vk: tl.constexpr,
              stride_oz: tl.constexpr, stride_oh: tl.constexpr, stride_om: tl.constexpr, stride_on: tl.constexpr,
              Z: tl.constexpr, H: tl.constexpr,
              N_CTX: tl.constexpr,
              HEAD_DIM: tl.constexpr,
              BLOCK_M: tl.constexpr,
              BLOCK_N: tl.constexpr,
              STAGE: tl.constexpr,
              NUM_CORES: tl.constexpr = 20,
              ):
    # Number of query blocks along the sequence dimension (M)
    NUM_BLOCKS_M = N_CTX // BLOCK_M
    # Total tasks = number of sequence blocks x batch size (Z) x number of attention heads (H)
    NUM_BLOCKS = NUM_BLOCKS_M * Z * H

    # Index of this program in the launch grid
    pid = tl.program_id(0)

    for block_idx in range(pid, NUM_BLOCKS, NUM_CORES):
        # Decompose the flat task index into (batch, head, query block)
        task_hz_idx = block_idx // NUM_BLOCKS_M
        task_m_idx = block_idx % NUM_BLOCKS_M
        off_z = task_hz_idx // H
        off_h = task_hz_idx % H
        # The Q strides are used for K, V and Out as well
        # NOTE(review): this presumes all four tensors share Q's batch/head strides - confirm
        qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh

        # Create block pointers for Q, K, V, Output
        Q_block_ptr = tl.make_block_ptr(
            base=Q + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_qm, stride_qk),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        K_block_ptr = tl.make_block_ptr(
            base=K + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_kn, stride_kk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, HEAD_DIM),
            order=(1, 0),
        )
        O_block_ptr = tl.make_block_ptr(
            base=Out + qvk_offset,
            shape=(N_CTX, HEAD_DIM),
            strides=(stride_om, stride_on),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, HEAD_DIM),
            order=(1, 0),
        )

        # Initialize offsets
        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)

        # Running row maximum starts at -inf. The initial 1.0 in l_i is multiplied by
        # alpha = exp(-inf - m_ij) = 0 on the first K/V block, so it does not leak
        # into the result.
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0

        # Initialize accumulator
        if HEAD_DIM < 256:
            acc_ptr = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
        else:
            # Row offset of this task inside the `acc` scratch tensor
            # NOTE(review): stride_qz // stride_qm and stride_qh // stride_qm act as row
            # counts here, which presumably relies on Q being contiguous - confirm
            acc_offset = (
                off_z.to(tl.int64) * stride_qz // stride_qm * HEAD_DIM +
                off_h.to(tl.int64) * stride_qh // stride_qm * HEAD_DIM +
                task_m_idx * BLOCK_M * HEAD_DIM
            )
            acc_ptr = acc + acc_offset

        q = tl.load(Q_block_ptr)

        # stage 1: off-band
        # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
        # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
        if STAGE & 1:
            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                                task_m_idx, sm_scale,
                                                BLOCK_M, HEAD_DIM, BLOCK_N,
                                                4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5)
        # stage 2: on-band (only taken for causal = True)
        if STAGE & 2:
            acc_ptr, l_i, m_i = _attn_fwd_inner(acc_ptr, l_i, m_i, q, K_block_ptr, V_block_ptr,
                                                task_m_idx, sm_scale,
                                                BLOCK_M, HEAD_DIM, BLOCK_N,
                                                2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5)

        # m_i becomes max + log(sum of exponentials) for each query row
        m_i += tl.math.log(l_i)

        # Normalize the accumulated output by the softmax denominator
        if HEAD_DIM < 256:
            accumulator = acc_ptr / l_i[:, None]
        else:
            row = tl.arange(0, BLOCK_M)[:, None]
            col_head_dim = tl.arange(0, HEAD_DIM)[None, :]
            block2d_acc = row * HEAD_DIM + col_head_dim
            accumulator = tl.load(acc_ptr + block2d_acc)
            accumulator = accumulator / l_i[:, None]

        m_ptrs = M + task_hz_idx * N_CTX + offs_m
        tl.store(m_ptrs, m_i)
        tl.store(O_block_ptr, accumulator.to(Out.type.element_ty))


class _attention(torch.autograd.Function):
    """Autograd wrapper that launches the fused attention forward kernel.

    Only ``forward`` is defined here; no ``backward`` is implemented in this section.
    """

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale, BM, BN):
        """
        Forward computation interface:
        Args:
            ctx: Context object
            q: Query tensor (Q), shape [Z, H, N_CTX, HEAD_DIM]
            k: Key tensor (K), shape [Z, H, N_CTX, HEAD_DIM]
            v: Value tensor (V), shape [Z, H, N_CTX, HEAD_DIM]
            causal: Whether to enable causal attention
            sm_scale: Scaling factor for QK product
            BM: Q block size (BLOCK_M)
            BN: K/V block size (BLOCK_N)
        Returns:
            o: Attention output tensor, shape [Z, H, N_CTX, HEAD_DIM]
        Raises:
            ValueError: If N_CTX is not divisible by both BM and BN.
        """
        # shape constraints
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        HEAD_DIM_V = v.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
        assert HEAD_DIM_K in {16, 32, 64, 128, 256}

        # The kernel performs no boundary checks and computes N_CTX // BLOCK_M query
        # blocks, so a partial tile would leave output rows unwritten and read K/V
        # past the end of the sequence.
        n_ctx = q.shape[2]
        if n_ctx % BM != 0 or n_ctx % BN != 0:
            raise ValueError(
                f"N_CTX ({n_ctx}) must be divisible by BM ({BM}) and BN ({BN})"
            )

        o = torch.empty_like(q)
        stage = 3 if causal else 1

        # Number of NPU cores (adjust based on hardware). It is used both as the launch
        # grid size and as the kernel's task stride, so the two always agree.
        num_cores = 20

        # The kernel only dereferences `acc` when HEAD_DIM >= 256; for smaller head
        # sizes a one-element placeholder avoids zero-filling a full float32 buffer.
        if HEAD_DIM_K >= 256:
            acc = torch.zeros((q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K),
                              dtype=torch.float32, device=q.device)
        else:
            acc = torch.empty((1,), dtype=torch.float32, device=q.device)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        _attn_fwd[(num_cores,)](
            q, k, v, M, o, acc, sm_scale,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            HEAD_DIM=HEAD_DIM_K,
            BLOCK_M=BM,
            BLOCK_N=BN,
            STAGE=stage,
            NUM_CORES=num_cores,
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.sm_scale = sm_scale
        ctx.HEAD_DIM = HEAD_DIM_K
        ctx.causal = causal
        return o


attention = _attention.apply

# (Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN)
TEST_CASES = [
    (1, 1, 128, 128, False, torch.float16, 32, 128),
    (1, 1, 128, 128, False, torch.bfloat16, 64, 128),
    (1, 2, 256, 256, False, torch.bfloat16, 32, 256),
    (2, 2, 128, 256, False, torch.float16, 64, 128),
    (4, 32, 64, 64, False, torch.float16, 32, 64),
    (4, 32, 1024, 64, False, torch.bfloat16, 64, 128),
    (4, 32, 4096, 64, False, torch.float16, 128, 128),
]


@pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN", TEST_CASES)
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, BM, BN):
    """Compare the Triton attention output with ``torch_npu.npu_fusion_attention``."""
    # Filter out non-integer cases; N_CTX must be divisible by BM and BN, and HEAD_DIM must be divisible by 16.
    if N_CTX % BM != 0 or N_CTX % BN != 0 or HEAD_DIM % 16 != 0:
        pytest.skip("Skipping non-divisible case")
    # The causal stages of the kernel contain tl.static_assert(BLOCK_M >= BLOCK_N).
    if causal and BM < BN:
        pytest.skip("Skipping causal case with BM < BN")

    torch.manual_seed(20)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())

    sm_scale = 0.5

    tri_out = attention(q, k, v, causal, sm_scale, BM, BN)

    # The reference has to apply the same masking as the Triton path. For the causal
    # case, pass a mask that is True strictly above the diagonal (positions to hide)
    # and allow no look-ahead tokens.
    if causal:
        atten_mask = torch.triu(torch.ones((N_CTX, N_CTX), device=DEVICE), diagonal=1).bool()
        next_tockens = 0
    else:
        atten_mask = None
        next_tockens = 65535

    ref_out = torch_npu.npu_fusion_attention(
        q, k, v, H,
        padding_mask=None,
        atten_mask=atten_mask,
        scale=sm_scale,
        keep_prob=1.0,
        input_layout="BNSD",
        pre_tockens=65535,
        next_tockens=next_tockens,
        sparse_mode=0,
    )[0]

    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2, equal_nan=True)
    print(f"[PASSED] Attention shape:({Z}, {H}, {N_CTX}, {HEAD_DIM}), BM: {BM}, BN: {BN}, dtype: {dtype}")


if __name__ == "__main__":
    for case in TEST_CASES:
        test_op(*case)
```

Out:

```bash
[PASSED] Attention shape:(1, 1, 128, 128), BM: 32, BN: 128, dtype: torch.float16
[PASSED] Attention shape:(1, 1, 128, 128), BM: 64, BN: 128, dtype: torch.bfloat16
[PASSED] Attention shape:(1, 2, 256, 256), BM: 32, BN: 256, dtype: torch.bfloat16
[PASSED] Attention shape:(2, 2, 128, 256), BM: 64, BN: 128, dtype: torch.float16
[PASSED] Attention shape:(4, 32, 64, 64), BM: 32, BN: 64, dtype: torch.float16
[PASSED] Attention shape:(4, 32, 1024, 64), BM: 64, BN: 128, dtype: torch.bfloat16
[PASSED] Attention shape:(4, 32, 4096, 64), BM: 128, BN: 128, dtype: torch.float16
```

上面的输出日志表明,Triton 内核的输出与 `torch_npu.npu_fusion_attention` 的输出在给定容差(`atol=1e-2`、`rtol=1e-2`)范围内一致。