矩阵乘法 (Matrix Multiplication)
在本节中,我们展示了使用 Triton 进行矩阵乘法的内核实现。
计算内核
以下 Triton 内核实现了一个带偏置项的批量矩阵乘法(Batched Matrix Multiplication with Bias): 计算公式为: $$ \text{output}[b, i, j] = \sum_k \text{x}[b, i, k] \cdot \text{y}[k, j] + \text{z}[b, i, j] $$ 其中:
x的形状为(A, B)y的形状为(B, C)z(偏置)的形状为(A, C)输出
output的形状为(A, C)
该内核假设单个 block 负责整个输出矩阵的计算,适用于小规模矩阵(A、B、C 较小且能被当前程序块完全覆盖)。
import pytest
import torch
import torch_npu
import triton
import triton.language as tl
@triton.jit
def triton_dot_2_Bias(
output_ptr, # 输出张量指针,形状 (A, C)
x_ptr, # 输入张量 x 指针,形状 (A, B)
y_ptr, # 输入张量 y 指针,形状 (B, C)
z_ptr, # 偏置张量 z 指针,形状 (A, C)
A: tl.constexpr, # 第一维度大小(batch / 行数)
B: tl.constexpr, # 共享维度(x 的列数,y 的行数)
C: tl.constexpr # 第二维度大小(列数)
):
# 创建索引向量
bidx = tl.arange(0, A) # [0, 1, ..., A-1],用于行维度
cidx = tl.arange(0, B) # [0, 1, ..., B-1],用于 x 的列 / y 的行
didx = tl.arange(0, C) # [0, 1, ..., C-1],用于列维度
# 构造 x 的线性索引:(A, B) -> 展平为 A*B
Xidx = bidx[:, None] * B + cidx[None, :] # 广播形成 (A, B) 索引网格
# 构造 y 的线性索引:(B, C) -> 展平为 B*C
Yidx = cidx[:, None] * C + didx[None, :] # (B, C) 索引网格
# 构造 z 和 output 的线性索引:(A, C)
Zidx = bidx[:, None] * C + didx[None, :] # (A, C) 索引网格
# 从全局内存加载数据
X = tl.load(x_ptr + Xidx) # 加载 (A, B) 子块
Y = tl.load(y_ptr + Yidx) # 加载 (B, C) 子块
Z = tl.load(z_ptr + Zidx) # 加载偏置 (A, C)
# 执行矩阵乘法并加上偏置
ret = tl.dot(X, Y) + Z # tl.dot 执行 (A, B) × (B, C) → (A, C)
# 写回结果到全局内存
oidx = bidx[:, None] * C + didx[None, :] # 与 Zidx 相同,可复用
tl.store(output_ptr + oidx, ret)
工具方法
以下辅助函数用于支持 Triton 内核的测试与验证,包括 PyTorch 参考实现、数据类型映射、随机张量生成及结果校验。
def torch_dot_Bias(x0, x1, bias):
"""PyTorch 参考实现:执行矩阵乘法并加上偏置项。"""
res = torch.matmul(x0, x1) + bias
return res
def get_torch_typename(dtype):
"""将字符串形式的数据类型映射为对应的 torch.dtype。"""
if dtype == 'float32':
tyname = torch.float32
elif dtype == 'int32':
tyname = torch.int32
elif dtype == 'int64':
tyname = torch.int64
elif dtype == 'float16':
tyname = torch.float16
elif dtype == 'int16':
tyname = torch.int16
elif dtype == 'int8':
tyname = torch.int8
elif dtype == 'bool':
tyname = torch.bool
elif dtype == 'bfloat16':
tyname = torch.bfloat16
else:
raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
return tyname
def generate_tensor(shape, dtype):
"""根据指定形状和数据类型生成随机张量,适配不同数值类型的取值范围。"""
if dtype == 'float32' or dtype == 'float16' or dtype == 'bfloat16':
return torch.randn(size=shape, dtype=eval('torch.' + dtype))
elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16':
return torch.randint(low=0, high=2000, size=shape, dtype=eval('torch.' + dtype))
elif dtype == 'int8':
return torch.randint(low=0, high=127, size=shape, dtype=eval('torch.' + dtype))
elif dtype == 'bool':
return torch.randint(low=0, high=2, size=shape).bool()
else:
raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
def validate_cmp(dtype, y_cal, y_ref):
"""在 NPU 上比较 Triton 计算结果与 PyTorch 参考结果,按数据类型设置容差或严格相等。"""
y_cal=y_cal.npu()
y_ref=y_ref.npu()
if dtype == 'float16':
torch.testing.assert_close(y_ref, y_cal, rtol=1e-03, atol=1e-03, equal_nan=True)
elif dtype == 'bfloat16':
torch.testing.assert_close(y_ref.to(torch.float32), y_cal.to(torch.float32), rtol=1e-03, atol=1e-03, equal_nan=True)
elif dtype == 'float32':
torch.testing.assert_close(y_ref, y_cal, rtol=1e-04, atol=1e-04, equal_nan=True)
elif dtype == 'int32' or dtype == 'int64' or dtype == 'int16' or dtype == 'int8':
assert torch.equal(y_cal, y_ref)
elif dtype == 'bool':
assert torch.equal(y_cal, y_ref)
else:
raise ValueError('Invalid parameter \"dtype\" is found : {}'.format(dtype))
参数化测试
使用 pytest 对 triton_dot_2_Bias 内核进行参数化功能验证,覆盖不同矩阵维度和数据类型组合。
# 测试用例配置:(A, B, C) 表示矩阵 x: (A,B), y: (B,C), bias/output: (A,C)
testlist = [
(16, 16, 16),
]
# 支持的数据类型列表(当前仅 float16)
typelist = ['float16',]
@pytest.mark.parametrize('A, B, C', testlist)
@pytest.mark.parametrize('sigtype', typelist)
def test_dot_2_Bias(sigtype, A, B, C):
"""对 triton_dot_2_Bias 内核进行端到端功能测试。"""
dtype = get_torch_typename(sigtype)
# 生成输入张量并移至 NPU
x0 = generate_tensor(shape=(A, B), dtype=sigtype).npu()
x1 = generate_tensor(shape=(B, C), dtype=sigtype).npu()
# 偏置项统一用 float32 生成(避免整数偏置导致精度问题)
if 'int' in sigtype:
bias = generate_tensor(shape=(A, C), dtype='int32').npu()
# 整数输入需转为 float32 计算后再转回目标类型
ans = torch_dot_Bias(x0.to(torch.float32), x1.to(torch.float32), bias.to(torch.float32)).to(dtype)
else:
bias = generate_tensor(shape=(A, C), dtype='float32').npu()
ans = torch_dot_Bias(x0, x1, bias).to(eval(f"torch.{dtype}"))
# 初始化输出张量
output = torch.zeros((A, C), dtype=dtype).npu()
# 启动 Triton 内核(grid=(1,1,1),单 block 执行)
triton_dot_2_Bias[1, 1, 1](output, x0, x1, bias, A, B, C, debug=True)
# 验证结果正确性
validate_cmp(sigtype, output, ans)
print(f"Test matmul with dtype={sigtype}, shape=({A},{B},{C}) PASSED!")
if __name__ == "__main__":
# 支持直接运行单个测试用例(便于调试)
test_dot_2_Bias("float16", 16, 16, 16)
输出示例:
Test matmul with dtype=float16, shape=(16,16,16) PASSED!
上面输出日志表明Triton和Pytorch上的输出结果完全一致。