矩阵乘法 (Matrix Multiplication)

在本节中,我们展示了使用 Triton 进行矩阵乘法的内核实现。

计算内核

以下 Triton 内核实现了一个带偏置项的批量矩阵乘法(Batched Matrix Multiplication with Bias): 计算公式为: $$ \text{output}[b, i, j] = \sum_k \text{x}[b, i, k] \cdot \text{y}[k, j] + \text{z}[b, i, j] $$ 其中:

  • x 的形状为 (A, B)

  • y 的形状为 (B, C)

  • z(偏置)的形状为 (A, C)

  • 输出 output 的形状为 (A, C)

该内核假设单个 block 负责整个输出矩阵的计算,适用于小规模矩阵(A、B、C 较小且能被当前程序块完全覆盖)。

import pytest
import torch
import torch_npu
import triton
import triton.language as tl


@triton.jit
def triton_dot_2_Bias(
    output_ptr,   # 输出张量指针,形状 (A, C)
    x_ptr,        # 输入张量 x 指针,形状 (A, B)
    y_ptr,        # 输入张量 y 指针,形状 (B, C)
    z_ptr,        # 偏置张量 z 指针,形状 (A, C)
    A: tl.constexpr,  # 第一维度大小(batch / 行数)
    B: tl.constexpr,  # 共享维度(x 的列数,y 的行数)
    C: tl.constexpr   # 第二维度大小(列数)
):
    # 创建索引向量
    bidx = tl.arange(0, A)  # [0, 1, ..., A-1],用于行维度
    cidx = tl.arange(0, B)  # [0, 1, ..., B-1],用于 x 的列 / y 的行
    didx = tl.arange(0, C)  # [0, 1, ..., C-1],用于列维度

    # 构造 x 的线性索引:(A, B) -> 展平为 A*B
    Xidx = bidx[:, None] * B + cidx[None, :]  # 广播形成 (A, B) 索引网格

    # 构造 y 的线性索引:(B, C) -> 展平为 B*C
    Yidx = cidx[:, None] * C + didx[None, :]  # (B, C) 索引网格

    # 构造 z 和 output 的线性索引:(A, C)
    Zidx = bidx[:, None] * C + didx[None, :]  # (A, C) 索引网格

    # 从全局内存加载数据
    X = tl.load(x_ptr + Xidx)  # 加载 (A, B) 子块
    Y = tl.load(y_ptr + Yidx)  # 加载 (B, C) 子块
    Z = tl.load(z_ptr + Zidx)  # 加载偏置 (A, C)

    # 执行矩阵乘法并加上偏置
    ret = tl.dot(X, Y) + Z  # tl.dot 执行 (A, B) × (B, C) → (A, C)

    # 写回结果到全局内存
    oidx = bidx[:, None] * C + didx[None, :]  # 与 Zidx 相同,可复用
    tl.store(output_ptr + oidx, ret)

工具方法

以下辅助函数用于支持 Triton 内核的测试与验证,包括 PyTorch 参考实现、数据类型映射、随机张量生成及结果校验。

def torch_dot_Bias(x0, x1, bias):
    """PyTorch 参考实现:执行矩阵乘法并加上偏置项。"""
    res = torch.matmul(x0, x1) + bias
    return res

def get_torch_typename(dtype):
    """将字符串形式的数据类型映射为对应的 torch.dtype。"""
    if dtype == 'float32':
        tyname = torch.float32
    elif dtype == 'int32':
        tyname = torch.int32
    elif dtype == 'int64':
        tyname = torch.int64
    elif dtype == 'float16':
        tyname = torch.float16
    elif dtype == 'int16':
        tyname = torch.int16
    elif dtype == 'int8':
        tyname = torch.int8
    elif dtype == 'bool':
        tyname = torch.bool
    elif dtype == 'bfloat16':
        tyname = torch.bfloat16
    else:
        raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
    return tyname

def generate_tensor(shape, dtype):
     """根据指定形状和数据类型生成随机张量,适配不同数值类型的取值范围。"""
    if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16':
        return torch.randn(size=shape, dtype=eval('torch.' + dtype))
    elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16':
        return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype))
    elif dtype == 'int8':
        return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype))
    elif dtype == 'bool':
        return torch.randint(low=0, high=2, size=shape).bool()
    else:
        raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))

def validate_cmp(dtype, y_cal, y_ref):
    """在 NPU 上比较 Triton 计算结果与 PyTorch 参考结果,按数据类型设置容差或严格相等。"""
    y_cal=y_cal.npu()
    y_ref=y_ref.npu()
    if dtype == 'float16':
        torch.testing.assert_close(y_ref, y_cal,  rtol=1e-03, atol=1e-03, equal_nan=True)
    elif dtype == 'bfloat16':
        torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32),  rtol=1e-03, atol=1e-03, equal_nan=True)
    elif dtype == 'float32':
        torch.testing.assert_close(y_ref, y_cal,  rtol=1e-04, atol=1e-04, equal_nan=True)
    elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8':
        assert torch.equal(y_cal, y_ref)
    elif dtype == 'bool':
        assert torch.equal(y_cal, y_ref)
    else:
        raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))

参数化测试

使用 pytesttriton_dot_2_Bias 内核进行参数化功能验证,覆盖不同矩阵维度和数据类型组合。

# 测试用例配置:(A, B, C) 表示矩阵 x: (A,B), y: (B,C), bias/output: (A,C)
testlist = [
    (16, 16, 16),
]

# 支持的数据类型列表(当前仅 float16)
typelist = ['float16',]

@pytest.mark.parametrize('A, B, C', testlist)
@pytest.mark.parametrize('sigtype', typelist)
def test_dot_2_Bias(sigtype, A, B, C):
    """对 triton_dot_2_Bias 内核进行端到端功能测试。"""
    dtype = get_torch_typename(sigtype)

    # 生成输入张量并移至 NPU
    x0 = generate_tensor(shape=(A, B), dtype=sigtype).npu()
    x1 = generate_tensor(shape=(B, C), dtype=sigtype).npu()

    # 偏置项统一用 float32 生成(避免整数偏置导致精度问题)
    if 'int' in sigtype:
        bias = generate_tensor(shape=(A, C), dtype='int32').npu()
        # 整数输入需转为 float32 计算后再转回目标类型
        ans = torch_dot_Bias(x0.to(torch.float32), x1.to(torch.float32), bias.to(torch.float32)).to(dtype)
    else:
        bias = generate_tensor(shape=(A, C), dtype='float32').npu()
        ans = torch_dot_Bias(x0, x1, bias).to(eval(f"torch.{dtype}"))

    # 初始化输出张量
    output = torch.zeros((A, C), dtype=dtype).npu()

    # 启动 Triton 内核(grid=(1,1,1),单 block 执行)
    triton_dot_2_Bias[1, 1, 1](output, x0, x1, bias, A, B, C, debug=True)

    # 验证结果正确性
    validate_cmp(sigtype, output, ans)
    print(f"Test matmul with dtype={sigtype}, shape=({A},{B},{C}) PASSED!")


if __name__ == "__main__":
    # 支持直接运行单个测试用例(便于调试)
    test_dot_2_Bias("float16", 16, 16, 16)

输出示例:

Test matmul with dtype=float16, shape=(16,16,16) PASSED!

上面输出日志表明Triton和Pytorch上的输出结果完全一致。