max_autotune 使用示例

max_autotune 是 Triton-Ascend 提供的增强自动调优装饰器,旨在简化多参数调优的代码编写。与社区版 autotune 要求用户手动枚举所有 triton.Config 不同,max_autotune 允许用户仅提供少量基础配置(如分块大小),并自动将相关的编译器选项(如 num_stagesenable_hivm_auto_cv_balance 等)纳入最优组合的搜索空间。用户也可以通过参数列表显式控制搜索范围。

适用场景:Ascend NPU 上的 cubemixvector 算子,尤其适合需要同时调整多个硬件相关参数的场景。

基本使用示例

以下示例演示了使用 max_autotune 对一个简单的向量加法 kernel 进行自动调优。与社区版 autotune 相比,max_autotune 还会自动将不同的编译器选项纳入调优空间,无需用户手动指定。

import torch
import torch_npu
import triton
import triton.language as tl
from triton.backends.ascend.runtime import max_autotune

def test_max_autotune():

    # 基础配置:只需提供分块大小,其他调优参数由装饰器自动生成
    base_configs = [
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
    ]

    @max_autotune(
        configs=base_configs,
        key=["numel"],
        kernel_type="vector",          # 算子类型:cube / mix / vector, 默认为mix
    )
    @triton.jit
    def triton_calc_kernel(
        out_ptr0, in_ptr0, in_ptr1, numel,
        BLOCK_SIZE: tl.constexpr
    ):
        pid = tl.program_id(0)
        idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = idx < numel

        # 模拟计算负载
        for i in range(10000):
            tmp0 = tl.load(in_ptr0 + idx, mask=mask, other=0.0)
            tmp1 = tl.load(in_ptr1 + idx, mask=mask, other=0.0)
            tmp2 = tl.math.exp(tmp0) + tmp1 + i
            tl.store(out_ptr0 + idx, tmp2, mask=mask)

    # 封装调用函数
    def triton_calc_func(x0, x1):
        n = x0.numel()
        y0 = torch.empty_like(x0)
        grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
        triton_calc_kernel[grid](y0, x0, x1, n)
        return y0

    # 与 PyTorch 参考结果对比
    def torch_calc_func(x0, x1):
        return torch.exp(x0) + x1 + 10000 - 1

    DEV = "npu"
    DTYPE = torch.float32
    N = 192 * 1024
    x0 = torch.randn((N,), dtype=DTYPE, device=DEV)
    x1 = torch.randn((N,), dtype=DTYPE, device=DEV)
    torch_ref = torch_calc_func(x0, x1)
    triton_cal = triton_calc_func(x0, x1)
    torch.testing.assert_close(triton_cal, torch_ref)

if __name__ == "__main__":
    test_max_autotune()
    print("success: test_max_autotune")

进阶使用:精确控制调优参数

用户可以通过 **tuning_params 显式指定需要调优的编译器选项及其取值列表;未指定的参数会使用内置默认值。以下示例展示了如何对多个参数进行组合搜索。

from triton.backends.ascend.runtime import max_autotune

def test_max_autotune():

    # 基础配置:只需提供分块大小,其他调优参数由装饰器自动生成
    base_configs = [
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
    ]

    @max_autotune(
        configs=base_configs,          # 基础配置列表
        key=["numel"],                 # 当 numel 变化时触发重新调优
        kernel_type="vector",          # 算子类型:cube / mix / vector
        # 以下参数为可选的调优列表,不提供时使用内置默认值
        num_stages=[1, 2],
        enable_ubuf_saving=[True, False]
    )
    @triton.jit
    def triton_calc_kernel(
        out_ptr0, in_ptr0, in_ptr1, numel,
        BLOCK_SIZE: tl.constexpr,
        **META
    ):
        pass