triton.language.device_assert
使用device_assert需要将环境变量TRITON_DEBUG的值设置为非0才能生效。
1. 函数概述
device_assert 用于在GPU运行时从设备端进行断言检查,如果条件不满足则输出错误信息。
triton.language.device_assert(cond, msg='', _semantic=None)
2. 规格
2.1 参数说明
参数 |
类型 |
默认值 |
含义说明 |
|---|---|---|---|
|
|
必需 |
运行时需要断言的条件表达式 |
|
|
|
断言失败时显示的错误消息 |
|
- |
- |
保留参数,暂不支持外部调用 |
2.2 类型支持
A3:
int8 |
int16 |
int32 |
uint8 |
uint16 |
uint32 |
uint64 |
int64 |
fp16 |
fp32 |
fp64 |
bf16 |
bool |
|
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
GPU |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
✓ |
Ascend A2/A3 |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
× |
✓ |
2.3 使用方法
import triton.language as tl
@triton.jit
def basic_device_assert_example(x_ptr, BLOCK_SIZE: tl.constexpr):
# 基本断言:检查程序ID
pid = tl.program_id(0)
tl.device_assert(pid >= 0, "Program ID must be non-negative")
offsets = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offsets)
# 检查数据有效性(比如检查张量中没有负值)
tl.device_assert(tl.min(x) >= 0, "All values must be non-negative")