diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index b946fc8875860b9245011d8141ca8f3cb2063bc0..5caa0932c73eb579e8dec1deec1689761f1f463e 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -1379,4 +1379,75 @@ def YieldOp : TransformDialectOp<"yield",
   ];
 }
 
+def LowerToArmSMEOp : TransformDialectOp<"lower_to_arm_sme",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Applies a fixed list of passes to the target module to lower supported
+    ops to legalized `arm_sme` dialect ops and types. The outer-product
+    fusion pass is part of the list only when `fuse_outer_products` is set.
+
+    The target handle is only read; the payload is modified in place.
+  }];
+
+  // NOTE(review): the template arguments of DeclareOpInterfaceMethods and
+  // DefaultValuedAttr could not be read from the original change and were
+  // reconstructed; the default value below is an assumption -- confirm.
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<BoolAttr, "true">:$fuse_outer_products);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::ModuleOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LowerToLLVMNewOp : TransformDialectOp<"lower_to_llvm_new",
+    [FunctionalStyleTransformOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Indicates that the entire module should be converted to the LLVM dialect.
+    This is expected to be the last transformation in a sequence.
+
+    `enable_async` adds the async lowering passes to the pipeline; the other
+    boolean attributes are forwarded to the vector-to-LLVM conversion. A
+    positive `vscale_range` is attached to every function definition as a
+    `vscale_range` passthrough entry.
+
+    The target handle is only read; the payload is modified in place.
+  }];
+
+  // NOTE(review): the template arguments of DeclareOpInterfaceMethods and
+  // DefaultValuedAttr could not be read from the original change and were
+  // reconstructed; the default values below are assumptions -- confirm.
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<BoolAttr, "false">:$reassociate_fp_reductions,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_index_optimizations,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_arm_neon,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_arm_sve,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_amx,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_x86vector,
+         DefaultValuedAttr<BoolAttr, "false">:$enable_async,
+         DefaultValuedAttr<I64Attr, "0">:$vscale_range);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::ModuleOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index c4238080533bef11c66aea0fdc1ee8665ff11fde..ba06950ea22503bb738842d1ea03616beb8a417c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dominance.h"
@@ -41,6 +42,11 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+
+#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/InitAllPasses.h"
+
 #include <optional>
 
 #define DEBUG_TYPE "transform-dialect"
@@ -2859,3 +2865,159 @@ void transform::YieldOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getOperandsMutable(), effects);
 }
+
+//===----------------------------------------------------------------------===//
+// LowerToArmSMEOp
+//===----------------------------------------------------------------------===//
+
+/// Runs the ArmSME lowering pipeline on `target` and reports a definite
+/// failure if the pipeline does not succeed. `rewriter`, `results` and
+/// `state` are unused.
+DiagnosedSilenceableFailure transform::LowerToArmSMEOp::applyToOne(
+    transform::TransformRewriter &rewriter, ModuleOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  PassManager pm(getContext());
+  // The vector legalization pass must run at the ModuleOp level.
+  // Legalize vector operations so they can be converted to ArmSME.
+  pm.addPass(arm_sme::createVectorLegalizationPass());
+
+  // Sprinkle some cleanups.
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());
+
+  // Passes that convert operations on vectors to ArmSME operations.
+  pm.addPass(createArithToArmSMEConversionPass());
+  pm.addPass(createConvertVectorToArmSMEPass());
+
+  // TODO: Leverage FMOPA 2Way for half precision?
+  // Fuse outer products.
+  if (getFuseOuterProducts())
+    pm.addPass(arm_sme::createOuterProductFusionPass());
+
+  // Convert operations on high-level vectors to loops.
+  pm.addPass(createConvertArmSMEToSCFPass());
+
+  // Enable locally-streaming mode with a new ZA state, only where required
+  // by the ops.
+  // NOTE(review): the op type of addNestedPass could not be read from the
+  // original change; func::FuncOp is an assumption -- confirm.
+  pm.addNestedPass<func::FuncOp>(arm_sme::createEnableArmStreamingPass(
+      arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
+      /*onlyIfRequiredByOps=*/true));
+
+  if (failed(pm.run(target)))
+    return emitDefiniteFailure()
+           << "failed to run the ArmSME lowering pipeline";
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is only read; the payload is modified.
+void transform::LowerToArmSMEOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::modifiesPayload(effects);
+  transform::onlyReadsHandle(getTargetMutable(), effects);
+}
+
+//===----------------------------------------------------------------------===//
+// LowerToLLVMNewOp
+//===----------------------------------------------------------------------===//
+
+/// Lowers `target` to the LLVM dialect with a fixed pass pipeline, then
+/// post-processes the resulting `llvm.func` ops: a positive `vscale_range` is
+/// appended to the passthrough attributes of function definitions, and
+/// pointer arguments are marked `llvm.noalias`. Reports a definite failure if
+/// the pipeline does not succeed.
+DiagnosedSilenceableFailure transform::LowerToLLVMNewOp::applyToOne(
+    transform::TransformRewriter &rewriter, ModuleOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // TODO: it is feasible to scope lowering at arbitrary level and introduce
+  // unrealized casts, but there needs to be the final module-wise cleanup in
+  // the end. Keep module-level for now.
+  MLIRContext *ctx = getContext();
+  PassManager pm(ctx);
+
+  // NOTE(review): the op type of every addNestedPass call, the dialect type
+  // of getLoadedDialect and the type tested by isa below could not be read
+  // from the original change; the ones used here are assumptions -- confirm.
+
+  // Lower multi-dimensional ops to SCF.
+  pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
+  pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
+  // Lower async.
+  if (getEnableAsync()) {
+    pm.addPass(createAsyncToAsyncRuntimePass());
+    pm.addPass(createAsyncRuntimeRefCountingPass());
+    pm.addPass(createAsyncRuntimeRefCountingOptPass());
+  }
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(memref::createExpandStridedMetadataPass());
+  // The expansion may create affine expressions. Get rid of them.
+  pm.addPass(createLowerAffinePass());
+  pm.addPass(createConvertSCFToCFPass());
+  // The ArmSME conversion is scheduled only if the dialect is loaded.
+  if (ctx->getLoadedDialect<arm_sme::ArmSMEDialect>())
+    pm.addNestedPass<func::FuncOp>(createConvertArmSMEToLLVMPass());
+  pm.addPass(createConvertComplexToLLVMPass());
+  pm.addPass(createConvertVectorToLLVMPass(ConvertVectorToLLVMPassOptions{
+      /*reassociateFPReductions=*/getReassociateFpReductions(),
+      /*force32BitVectorIndices=*/getEnableIndexOptimizations(),
+      /*amx=*/getEnableAmx(),
+      /*armNeon=*/getEnableArmNeon(),
+      /*armSVE=*/getEnableArmSve(),
+      /*x86Vector=*/getEnableX86vector()}));
+  pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
+  pm.addPass(createFinalizeMemRefToLLVMConversionPass());
+  if (getEnableAsync())
+    pm.addPass(createConvertAsyncToLLVMPass());
+  pm.addPass(createConvertOpenMPToLLVMPass());
+  pm.addPass(createConvertFuncToLLVMPass());
+  pm.addPass(createConvertControlFlowToLLVMPass());
+  pm.addPass(createArithToLLVMConversionPass());
+  pm.addPass(createConvertIndexToLLVMPass());
+  pm.addPass(createReconcileUnrealizedCastsPass());
+  if (failed(pm.run(target)))
+    return emitDefiniteFailure() << "failed to run the LLVM lowering pipeline";
+
+  // A positive vscale_range is recorded as a passthrough entry.
+  // NOTE(review): the value is emitted as a single integer string; confirm
+  // this is the encoding the LLVM translation expects for vscale_range.
+  if (getVscaleRange() > 0) {
+    auto vscaleRange = ArrayAttr::get(
+        ctx, {StringAttr::get(ctx, "vscale_range"),
+              StringAttr::get(ctx, llvm::Twine(getVscaleRange()))});
+
+    target->walk([&](LLVM::LLVMFuncOp funcOp) {
+      // Declarations are left untouched.
+      if (funcOp.getBody().empty())
+        return;
+      // Append to, rather than overwrite, existing passthrough entries.
+      SmallVector<Attribute> passthrough;
+      if (auto existing = funcOp->getAttrOfType<ArrayAttr>("passthrough"))
+        llvm::append_range(passthrough, existing.getValue());
+      passthrough.push_back(vscaleRange);
+      funcOp->setAttr("passthrough", ArrayAttr::get(ctx, passthrough));
+    });
+  }
+
+  // Make all pointer arguments noalias for now.
+  // FIXME: this is a terrible hack!
+  target->walk([](LLVM::LLVMFuncOp funcOp) {
+    for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
+      Type paramType = funcOp.getFunctionType().getParamType(i);
+      if (!isa<LLVM::LLVMPointerType>(paramType))
+        continue;
+      funcOp.setArgAttr(i, "llvm.noalias", UnitAttr::get(funcOp.getContext()));
+    }
+  });
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// The target handle is only read; the payload is modified.
+void transform::LowerToLLVMNewOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::modifiesPayload(effects);
+  transform::onlyReadsHandle(getTargetMutable(), effects);
+}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir
new file mode 100644
index 0000000000000000000000000000000000000000..8ac11c9d350de56a72878eccea760cd27845ae7e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul2.mlir
@@ -0,0 +1,65 @@
+// NOTE(review): this file contains only transform schedules; it has no RUN
+// lines, CHECK lines or payload IR -- confirm how it is meant to be executed.
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+
+    %tiled_linalg_op, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [[4], [4], 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+
+    transform.structured.vectorize %tiled_linalg_op vector_sizes [[4], [4], 1] : !transform.any_op
+
+    %1 = transform.bufferization.one_shot_bufferize %arg0 {bufferize_function_boundaries = true} : (!transform.any_op) -> !transform.any_op
+
+    %2 = transform.structured.match ops{["func.func"]} in %1 : (!transform.any_op) -> !transform.any_op
+
+    %3 = transform.apply_registered_pass "convert-linalg-to-loops" to %2 : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %3 {
+      transform.apply_patterns.vector.lower_masked_transfers
+      transform.apply_patterns.vector.transfer_permutation_patterns
+      transform.apply_patterns.vector.reduction_to_contract
+    } : !transform.op<"func.func">
+
+    transform.apply_patterns to %3 {
+      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
+      transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
+      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+      transform.apply_patterns.vector.lower_masks
+      transform.apply_patterns.canonicalization
+    } : !transform.op<"func.func">
+
+    %5 = transform.structured.match interface{LoopLikeInterface} in %1 : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_licm to %5 : !transform.any_op
+
+    transform.loop.hoist_loop_invariant_subsets %5 : !transform.any_op
+
+    transform.yield
+  }
+  transform.named_sequence @arm_sme_lowering_schedule(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.lower_to_arm_sme %arg0 : !transform.any_op
+    transform.yield %arg0 : !transform.any_op
+  }
+  transform.named_sequence @lower_to_llvm_schedule(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    transform.lower_to_llvm_new %arg0 {enable_arm_sve = true, enable_index_optimizations = true, vscale_range = 0 : i64} : !transform.any_op
+    transform.yield %arg0 : !transform.any_op
+  }
+  transform.named_sequence @__transform_main_next(%arg0: !transform.any_op) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {deduplicate} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.include @arm_sme_lowering_schedule failures(propagate) (%1) : (!transform.any_op) -> !transform.any_op
+    %3 = transform.apply_registered_pass "cse" to %2 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %3 {
+      transform.apply_patterns.linalg.tiling_canonicalization
+      transform.apply_patterns.scf.for_loop_canonicalization
+    } : !transform.any_op
+    %4 = transform.structured.match interface{LoopLikeInterface} in %3 : (!transform.any_op) -> !transform.any_op
+    transform.apply_licm to %4 : !transform.any_op
+    %5 = transform.structured.match ops{["func.func"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %6 = transform.structured.hoist_redundant_vector_transfers %5 : (!transform.any_op) -> !transform.any_op
+    %7 = transform.structured.hoist_redundant_vector_broadcasts %6 : (!transform.any_op) -> !transform.any_op
+    %8 = transform.apply_registered_pass "canonicalize" to %7 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}