max_autotune 使用示例
max_autotune 是 Triton-Ascend 提供的增强自动调优装饰器,旨在简化多参数调优的代码编写。与社区版 autotune 要求用户手动枚举所有 triton.Config 不同,max_autotune 允许用户仅提供少量基础配置(如分块大小),并自动将相关的编译器选项(如 num_stages、enable_hivm_auto_cv_balance 等)纳入最优组合的搜索空间。用户也可以通过参数列表显式控制搜索范围。
适用场景:Ascend NPU 上的 cube、mix、vector 算子,尤其适合需要同时调整多个硬件相关参数的场景。
基本使用示例
以下示例演示了使用 max_autotune 对一个简单的向量加法 kernel 进行自动调优。与社区版 autotune 相比,max_autotune 还会自动将不同的编译器选项纳入调优空间,无需用户手动指定。
import torch
import torch_npu
import triton
import triton.language as tl
from triton.backends.ascend.runtime import max_autotune
def test_max_autotune():
# 基础配置:只需提供分块大小,其他调优参数由装饰器自动生成
base_configs = [
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
]
@max_autotune(
configs=base_configs,
key=["numel"],
kernel_type="vector", # 算子类型:cube / mix / vector, 默认为mix
)
@triton.jit
def triton_calc_kernel(
out_ptr0, in_ptr0, in_ptr1, numel,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = idx < numel
# 模拟计算负载
for i in range(10000):
tmp0 = tl.load(in_ptr0 + idx, mask=mask, other=0.0)
tmp1 = tl.load(in_ptr1 + idx, mask=mask, other=0.0)
tmp2 = tl.math.exp(tmp0) + tmp1 + i
tl.store(out_ptr0 + idx, tmp2, mask=mask)
# 封装调用函数
def triton_calc_func(x0, x1):
n = x0.numel()
y0 = torch.empty_like(x0)
grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
triton_calc_kernel[grid](y0, x0, x1, n)
return y0
# 与 PyTorch 参考结果对比
def torch_calc_func(x0, x1):
return torch.exp(x0) + x1 + 10000 - 1
DEV = "npu"
DTYPE = torch.float32
N = 192 * 1024
x0 = torch.randn((N,), dtype=DTYPE, device=DEV)
x1 = torch.randn((N,), dtype=DTYPE, device=DEV)
torch_ref = torch_calc_func(x0, x1)
triton_cal = triton_calc_func(x0, x1)
torch.testing.assert_close(triton_cal, torch_ref)
if __name__ == "__main__":
test_max_autotune()
print("success: test_max_autotune")
进阶使用:精确控制调优参数
用户可以通过 **tuning_params 显式指定需要调优的编译器选项及其取值列表;未指定的参数会使用内置默认值。以下示例展示了如何对多个参数进行组合搜索。
from triton.backends.ascend.runtime import max_autotune
def test_max_autotune():
# 基础配置:只需提供分块大小,其他调优参数由装饰器自动生成
base_configs = [
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
]
@max_autotune(
configs=base_configs, # 基础配置列表
key=["numel"], # 当 numel 变化时触发重新调优
kernel_type="vector", # 算子类型:cube / mix / vector
# 以下参数为可选的调优列表,不提供时使用内置默认值
num_stages=[1, 2],
enable_ubuf_saving=[True, False]
)
@triton.jit
def triton_calc_kernel(
out_ptr0, in_ptr0, in_ptr1, numel,
BLOCK_SIZE: tl.constexpr,
**META
):
pass