From b9f1f0fbf9e004e49f444f44d89552b9e0f2bb6a Mon Sep 17 00:00:00 2001 From: yanminghui Date: Fri, 22 Aug 2025 03:30:47 +0000 Subject: [PATCH 1/2] mask optimization for argmax --- .../lib/TritonToLinalg/LoadStoreConverter.cpp | 13 +++++- .../triton/Dialect/Triton/IR/TritonOps.td | 22 +++++++++- triton_patch/lib/Dialect/Triton/IR/Ops.cpp | 44 +++++++++++++++++++ triton_patch/python/src/ir.cc | 7 +-- .../python/triton_patch/language/semantic.py | 4 +- 5 files changed, 83 insertions(+), 7 deletions(-) diff --git a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp index 210e135..9aee1e0 100644 --- a/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp +++ b/ascend/triton-adapter/lib/TritonToLinalg/LoadStoreConverter.cpp @@ -279,6 +279,17 @@ LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, op, "can not lower uncontinuout masked loads"); } + if (op.getOptMask() && mask && other) { + auto scalarOther = + mlir::ConverterUtils::getScalarValue(other, loc, rewriter); + assert( + scalarOther && + "other value used in masked load produced by unsupported instruction!"); + auto loc = allocOp->getLoc(); + rewriter.create(loc, ValueRange{scalarOther},ValueRange{allocOp}); + other = NULL; + } + if (other) { auto scalarOther = mlir::ConverterUtils::getScalarValue(other, loc, rewriter); @@ -926,4 +937,4 @@ StoreConverter::matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, rewriter.eraseOp(op); return success(); } -} // namespace LoadStoreConverter \ No newline at end of file +} // namespace LoadStoreConverter diff --git a/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td b/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td index 5e1118e..62aba5a 100644 --- a/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/triton_patch/include/triton/Dialect/Triton/IR/TritonOps.td @@ -253,12 +253,32 @@ def TT_LoadOp : TT_Op<"load", [ OptionalAttr:$padding, DefaultValuedAttr:$cache, DefaultValuedAttr:$evict, - DefaultValuedAttr:$isVolatile + DefaultValuedAttr:$isVolatile, + DefaultValuedAttr:$optMask ); let results = (outs TT_Type:$result); let builders = [ + // A tensor of pointers or a pointer to a scalar + OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor pointer with boundary check and padding + OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar with mask + OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar with mask and other + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A utility function to build the operation with all attributes + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "ArrayRef":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, + "triton::EvictionPolicy":$evict, "bool":$isVolatile, "bool":$optMask)>, + // A tensor of pointers or a pointer to a scalar OpBuilder<(ins "Value":$ptr, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, diff --git a/triton_patch/lib/Dialect/Triton/IR/Ops.cpp b/triton_patch/lib/Dialect/Triton/IR/Ops.cpp index 7873e99..0018072 100644 --- a/triton_patch/lib/Dialect/Triton/IR/Ops.cpp +++ b/triton_patch/lib/Dialect/Triton/IR/Ops.cpp @@ -44,6 +44,13 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, cache, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + CacheModifier cache, EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, /*padding=*/std::nullopt, + cache, evict, isVolatile, optMask); +} + void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, ArrayRef boundaryCheck, std::optional padding, CacheModifier cache, @@ -52,6 +59,14 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, padding, cache, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, /*mask=*/{}, /*other=*/{}, boundaryCheck, + padding, cache, evict, isVolatile, optMask); +} + void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, Value mask, CacheModifier cache, EvictionPolicy evict, bool isVolatile) { @@ -60,6 +75,14 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, /*padding=*/std::nullopt, cache, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, CacheModifier cache, EvictionPolicy evict, + bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, mask, /*other=*/{}, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); +} + void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, Value mask, Value other, CacheModifier cache, EvictionPolicy evict, bool isVolatile) { @@ -68,6 +91,14 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, /*padding=*/std::nullopt, cache, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + LoadOp::build(builder, state, ptr, mask, other, + /*boundaryCheck=*/ArrayRef{}, + /*padding=*/std::nullopt, cache, evict, isVolatile, optMask); +} + void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, Value mask, Value other, ArrayRef boundaryCheck, std::optional padding, CacheModifier cache, @@ -81,6 +112,19 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, evict, isVolatile); } +void LoadOp::build(OpBuilder &builder, OperationState &state, Value ptr, + Value mask, Value other, ArrayRef boundaryCheck, + std::optional padding, CacheModifier cache, + EvictionPolicy evict, bool isVolatile, bool optMask) { + auto paddingAttr = + padding.has_value() + ? PaddingOptionAttr::get(builder.getContext(), padding.value()) + : PaddingOptionAttr(); + LoadOp::build(builder, state, ptr, mask, other, + builder.getDenseI32ArrayAttr(boundaryCheck), paddingAttr, cache, + evict, isVolatile, optMask); +} + // load(ptr, splat(1), ...) -> load(ptr, ...) // load(ptr, splat(0), other, ...) -> other struct CanonicalizeMaskedLoadPattern : public OpRewritePattern { diff --git a/triton_patch/python/src/ir.cc b/triton_patch/python/src/ir.cc index a2fd115..f2556f9 100644 --- a/triton_patch/python/src/ir.cc +++ b/triton_patch/python/src/ir.cc @@ -1266,10 +1266,11 @@ void init_triton_ir(py::module &&m) { .def("create_masked_load", [](TritonOpBuilder &self, Value &ptrs, Value &mask, std::optional &other, CacheModifier cacheModifier, - EvictionPolicy evictionPolicy, bool isVolatile) -> Value { - return self.create(ptrs, mask, other.value_or(Value()), + EvictionPolicy evictionPolicy, bool isVolatile, bool optMask) -> Value { + return self.create(ptrs, mask, other.value_or(Value()), cacheModifier, evictionPolicy, - isVolatile); + isVolatile, optMask); + }) .def("create_masked_store", [](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask, diff --git a/triton_patch/python/triton_patch/language/semantic.py b/triton_patch/python/triton_patch/language/semantic.py index 132efa6..65037a8 100644 --- a/triton_patch/python/triton_patch/language/semantic.py +++ b/triton_patch/python/triton_patch/language/semantic.py @@ -380,7 +380,7 @@ def not_(input: tl.tensor, builder: ir.builder): return invert(input, builder) -def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, mask_opt, builder): # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` if not ptr.type.scalar.is_ptr(): raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") @@ -438,7 +438,7 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ else: ret = tl.tensor( builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, - is_volatile), dst_ty) + is_volatile, mask_opt), dst_ty) # Do not cast back to int1 when is_bool=true. We directly use the int8 tensor given by tl.load if is_bool: ret.was_bool_to_int8 = True -- Gitee From ce405c6a9969cc515e7d611a9368639d27082ce7 Mon Sep 17 00:00:00 2001 From: yanminghui Date: Mon, 22 Sep 2025 07:42:34 +0000 Subject: [PATCH 2/2] Use hint instead of attribute --- .../python/triton_patch/compiler/code_generator.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/triton_patch/python/triton_patch/compiler/code_generator.py b/triton_patch/python/triton_patch/compiler/code_generator.py index d8ee4fd..397b427 100644 --- a/triton_patch/python/triton_patch/compiler/code_generator.py +++ b/triton_patch/python/triton_patch/compiler/code_generator.py @@ -1132,6 +1132,13 @@ class CodeGenerator(ast.NodeVisitor): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] + + # Get current line number and hints + line_num = node.lineno + function_def = self.jit_fn.parse() + line_flagtree_hints = getattr(function_def.body[0], 'line_flagtree_hints', {}) + flagtree_hints = line_flagtree_hints.get(line_num) + if isinstance(fn, JITFunction): _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) @@ -1141,6 +1148,12 @@ class CodeGenerator(ast.NodeVisitor): if '_generator' in sig.parameters: extra_kwargs['_generator'] = self try: + # Special handling for tl.load with hints + if fn.__name__ == "load" and flagtree_hints is not None and 'mask_opt' in flagtree_hints: + print(f"tl.load at line {line_num} has attribute {flagtree_hints}") + if 'mask_opt' not in kws: + kws['mask_opt'] = True + return fn(*args, **extra_kwargs, **kws) except Exception as e: # Normally when we raise a CompilationError, we raise it as -- Gitee