triton.language.load

1. OP 概述

原型:

triton.language.load(
 pointer,
 mask=None,
 other=None,
 boundary_check=(),
 padding_option='',
 cache_modifier='',
 eviction_policy='',
 volatile=False,
 _semantic=None
)

简介:返回一个Tensor/Scalar,其值从GlobalMemory中pointer参数指向的位置加载。

2. OP 规格

2.1 参数说明

参数名

类型

说明

pointer

triton.PointerType
tensor<triton.PointerType>
triton.PointerType<tensor>(来源于tl.make_block_ptr

指向GM上待读取数据的指针

mask

int1tensor<int1>

可选参数,当且仅当pointer 不来源于tl.make_block_ptr时可传入
mask[i]==False ,则不会读取pointer[i]指向的数据,是True则正常读取
pointer来源于tl.make_block_ptr,则mask必须是None

other

tensorscalar

可选参数,当且仅当mask!=None时可传入
mask[i]==False ,将返回值的第i个位置设置为other[i]other(若otherscalar类型), 需要支持tensor,因为tritonGPU社区上是tensor和scalar都支持的

boundary_check

tuple(int)

可选参数,当且仅当pointer来源于tl.make_block_ptr时可传入
整数元组,指示需要做边界检查的维度

padding_option

"""zero""nan"

可选参数,当且仅当boundary_check不为空时可传入
表示访问越界时填充的值

cache_modifier

"""ca""cg"

可选参数,控制NVIDIA PTX上的cache选项,对Ascend硬件无效

eviction_policy

str

控制NVIDIA PTX的eviction策略, 对Ascend硬件无效

volatile

str

控制NVIDIA PTX的volatile选项, 对Ascend硬件无效

_semantic

-

保留参数,暂不支持外部调用

当前910代际均还不支持cache_modifier,eviction_policy, volatile等参数

2.2 支持规格

2.2.1 DataType 支持

int8

int16

int32

uint8

uint16

uint32

uint64

int64

fp16

fp32

fp64

bf16

bool

GPU

Ascend A2/A3

×

×

×

×

×

结论:Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。

2.2.2 Shape 支持

支持维度范围

GPU

支持scalar和1~5维tensor

Ascend A2/A3

支持scalar和1~5维 tensor

结论:在 Shape 方面,GPU 与 Ascend 平台无差异,均支持 1 至 5 维张量。

2.2.3 社区约束

  1. pointer是一个单指针:

    • 此时tl.load返回一个标量

    • maskother必须是标量

    • other会隐式类型转换成pointer.dtype.element_ty的数据类型

    • 此时不允许传入boundary_checkpadding_option

  2. pointer是一个N-Dimensional tensor:

    • 此时tl.load返回一个与pointershape相同的N-Dimensional tensor

    • maskother会隐式广播到和pointer相同的shape

    • 此时不允许传入boundary_checkpadding_option

  3. pointer来自于tl.make_block_ptr:

    • 此时maskother 必须是None

    • 此时可以通过boundary_checkpadding_option设置边界检查和越界补充值

2.3 特殊限制说明

相对社区能力缺失且无法实现

Ascend 对比 GPU 缺失uint8、uint16、uint32、uint64、fp64的支持能力(硬件限制)。

差异点

描述

解决途径

不支持padding_option入参

当前使用的社区分支新增padding_option入参,用于越界元素填充策略。

可软件开发支持

与分支、循环语句搭配使用时的泛化性问题

当前tl.load的pointermask的计算过程,如果涉及较复杂的循环和分支语句,可能会出现编译问题

大量泛化测试暴露问题,迭代解决

2.4 使用方法

以下示例中通过triton_ldst_indirect_07_kerneltriton_ldst_indirect_07_func的配合调用,实现了torch_ldst_indirect_07_func的功能:

@triton.jit
def triton_ldst_indirect_07_kernel(
    out_ptr0, in_ptr0, in_ptr1, in_ptr2, stride_in_r,
    XS: tl.constexpr, RS: tl.constexpr
):
    pid = tl.program_id(0)
    in_idx0 = pid * XS + tl.arange(0, XS)
    in_idx1 = tl.arange(0, RS)
    tmp0 = tl.load(in_ptr0 + in_idx0)
    tmp1 = tl.load(in_ptr1 + in_idx1)
    in_idx2 = tmp0[:, None] * stride_in_r + tmp1[None, :]
    tmp2 = tl.load(in_ptr2 + in_idx2)
    out0_idx = in_idx0[:, None] * RS + in_idx1[None, :]
    tl.store(out_ptr0 + out0_idx, tmp2)

def triton_ldst_indirect_07_func(xr, xc, x2, xs, rs):
    nr = x2.size()[0]
    nc = xc.numel()
    stride_in_r = x2.stride()[0]
    assert nr == xs, "test only single core"
    y0 = torch.empty((nr, nc), dtype=x2.dtype, device=x2.device)
    triton_ldst_indirect_07_kernel[nr // xs, 1, 1](
        y0, xr, xc, x2, stride_in_r, XS = xs, RS = rs)
    return y0

def torch_ldst_indirect_07_func(xr, xc, x2):
    flatten_idx = (xr[:, None] * x2.stride()[0] + xc[None, :]).flatten()
    extracted = x2.flatten()[flatten_idx].reshape([xr.numel(), xc.numel()])
    return extracted

DEV = "npu"
DTYPE = torch.float32
offset = 8
N0, N1 = 16, 32
blocksize = 4
lowdimsize = N0
assert N1 >= N0+offset, "N1 must be >= N0+offset"
assert N0 == lowdimsize, "N0 must be == lowdimsize"
xc = offset + torch.arange(0, N0, device=DEV)
xr = torch.arange(0, blocksize, device=DEV)
x2 = torch.randn((blocksize, N1), dtype=DTYPE, device=DEV)
torch_ref = torch_ldst_indirect_07_func(xr, xc, x2)
triton_cal = triton_ldst_indirect_07_func(xr, xc, x2, blocksize, lowdimsize)
torch.testing.assert_close(triton_cal, torch_ref)