From 3e4abc996c7a1b63c1b2d09ff3d39f15d026e44c Mon Sep 17 00:00:00 2001 From: Daniel Date: Fri, 15 Aug 2025 08:00:11 +0200 Subject: [PATCH] Added f8e5m2 and f8e4m3fn support for the expand ops pass in the arith dialect + test --- .../mlir/Dialect/Arith/Transforms/Passes.h | 6 + .../mlir/Dialect/Arith/Transforms/Passes.td | 4 + .../Dialect/Arith/Transforms/ExpandOps.cpp | 604 +++++++++++++++++- mlir/test/Dialect/Arith/expand-ops.mlir | 42 +- 4 files changed, 642 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 9dc262cc72ed..1be7140ac71d 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -57,6 +57,12 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns); /// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts. void populateExpandBFloat16Patterns(RewritePatternSet &patterns); +/// Add patterns to expand Arith f8E5M2 patterns to lower level bitcasts/shifts. +void populateExpandF8E5M2Patterns(RewritePatternSet &patterns); + +/// Add patterns to expand Arith f8E4M3FN patterns to lower level bitcasts/shifts. +void populateExpandF8E4M3FNPatterns(RewritePatternSet &patterns); + /// Add patterns to expand Arith ops. 
void populateArithExpandOpsPatterns(RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td index 1517f71f1a7c..eba1ec430550 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td @@ -17,6 +17,10 @@ def ArithExpandOpsPass : Pass<"arith-expand"> { let options = [ Option<"includeBf16", "include-bf16", "bool", /*default=*/"false", "Enable the BF16 expansion patterns">, + Option<"includeF8E5M2", "include-f8e5m2", "bool", /*default=*/"false", + "Enable the F8E5M2 expansion patterns">, + Option<"includeF8E4M3FN", "include-f8e4m3fn", "bool", /*default=*/"false", + "Enable the F8E4M3FN expansion patterns">, ]; } diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 54be644a7101..b328752886e0 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -335,6 +335,553 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { } }; +struct F8E5M2ExtFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f8E5M2 → f32 for now + if (!llvm::isa(operandETy) || !resultETy.isF32()) { + return rewriter.notifyMatchFailure(op, "not a ext of f8E5M2 to f32."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = 
shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Bitcast fp8 to raw uint8 + Value bits = b.create(i8Ty, operand); + // Zero-extend to 32 bits + Value bits32 = b.create(i32Ty, bits); + + // Extract sign (bit 7) → move to f32 sign position (bit 31) + Value sign = b.create( + bits32, createConst(op.getLoc(), i32Ty, 7, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent (bits 2–6) → move to f32 exponent position (bits 23–30) + Value e5m2_exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 2, rewriter)); + e5m2_exponent = b.create( + e5m2_exponent, createConst(op.getLoc(), i32Ty, 0x1F, rewriter)); + + // Extract mantissa (bits 0–1) + Value e5m2_mantissa = b.create( + bits32, + createConst(op.getLoc(), i32Ty, 0x3, rewriter)); // 0b11 mask for 2 bits + + // Bias exponent: f8E5M2 has a bias of 15, so we need to subtract 15 + Value exponent = b.create( + e5m2_exponent, createConst(op.getLoc(), i32Ty, 15, rewriter)); + Value float_exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 127, rewriter)); + + // Special case handling for NaNs, Infs, subnormals + // Subnormal handling + // if (e5m2_mantissa >= 0x2) + Value isSubnormal = + b.create(arith::CmpIPredicate::sge, e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x2, rewriter)); + // result = sign << 31 | (float_exponent) << 23 | (e5m2_mantissa & 0x1) << + // (23 - 1); + Value subnormalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + b.create( + e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)), + createConst(op.getLoc(), i32Ty, 22, rewriter)))); + + // if (e5m2_mantissa == 0x1) + Value isSubnormal2 = + b.create(arith::CmpIPredicate::eq, e5m2_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + // result = sign << 31 | (float_exponent - 1) << 23; + Value 
subnormalResult2 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 1, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter))); + + // Is normal if (e5m2_exponent > 0) + Value isNormal = + b.create(arith::CmpIPredicate::sgt, e5m2_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter)); + + // else nan + Value NaN = createConst(op.getLoc(), i32Ty, 0x7FC00000, rewriter); + + // Combine sign | exponent | mantissa + Value normalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + e5m2_mantissa, createConst(op.getLoc(), i32Ty, 21, rewriter)))); + + // Select the appropriate result based on the conditions + Value result = b.create( + isNormal, normalResult, + b.create( + isSubnormal, subnormalResult, + b.create(isSubnormal2, subnormalResult2, NaN))); + + // Bitcast to f32 + result = b.create(f32Ty, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F8E5M2TruncFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value operand = op.getOperand(); + Type operandTy = operand.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultTy = op.getType(); + Type resultETy = getElementTypeOrSelf(resultTy); + + if (!resultETy.isFloat8E5M2()) { + return rewriter.notifyMatchFailure(op, "not a truncf to fp8e5m2"); + } + + if (op.getRoundingmodeAttr()) { + return rewriter.notifyMatchFailure( + op, "only applicable to default rounding mode."); + } + + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + + if (auto shapedTy = mlir::dyn_cast(operandTy)) { + i8Ty = 
shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Normalize to f32 + if (operandETy.getIntOrFloatBitWidth() < 32) { + operand = b.create(f32Ty, operand, op.getFastmathAttr()); + } else if (operandETy.getIntOrFloatBitWidth() > 32) { + operand = b.create( + f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); + } + + // Bitcast f32 to i32 for bit manipulations + Value bits = b.create(i32Ty, operand); + + // Extract sign bit (bit 31) + Value sign = b.create( + bits, createConst(op.getLoc(), i32Ty, 31, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent bits (bits 30:23) + Value exponent = b.create( + bits, createConst(op.getLoc(), i32Ty, 23, rewriter)); + exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 0xFF, rewriter)); + + // Compute unbiased exponent (exponent - 127) + Value exponentBias = createConst(op.getLoc(), i32Ty, 127, rewriter); + Value unbiasedExp = b.create(exponent, exponentBias); + + // Extract mantissa bits (bits 22:0) + Value mantissa = b.create( + bits, createConst(op.getLoc(), i32Ty, 0x7FFFFF, rewriter)); + + // Add fp8 bias (15) + Value fp8Bias = createConst(op.getLoc(), i32Ty, 15, rewriter); + Value fp8Exp = b.create(unbiasedExp, fp8Bias); + + // Prepare mantissa for rounding: + // We need to reduce mantissa from 23 bits → 2 bits mantissa in fp8. 
+ // To round to nearest, shift mantissa right by 21 (23 - 2) + Value mantissaShift = createConst(op.getLoc(), i32Ty, 21, rewriter); + Value mantissaTruncated = b.create(mantissa, mantissaShift); + + Value e5m2_mantissa = b.create( + mantissaTruncated, + createConst(op.getLoc(), i32Ty, 0x3, rewriter)); // 0b11 mask for 2 bits + + // Compose final fp8 bits: sign (bit7), expFinal (bits 6:2), mantissaFinal + // (bits 1:0) + Value signShifted = b.create( + sign, createConst(op.getLoc(), i32Ty, 7, rewriter)); + Value expShifted = b.create( + fp8Exp, createConst(op.getLoc(), i32Ty, 2, rewriter)); + Value resultInt = b.create(signShifted, expShifted); + resultInt = b.create(resultInt, e5m2_mantissa); + + // Subnormal cases + // if (e5m2_exponent > 31) + Value isSubnormal = + b.create(arith::CmpIPredicate::sgt, fp8Exp, + createConst(op.getLoc(), i32Ty, 31, rewriter)); + // return sign << 7 | 0x7C; + Value subnormalResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + createConst(op.getLoc(), i32Ty, 0x7C, rewriter) // 0b01111100 + ); + // if ((e5m2_exponent >= -1) && (e5m2_exponent <= 0)) + Value isSubnormal2 = b.create( + b.create(arith::CmpIPredicate::sge, fp8Exp, + createConst(op.getLoc(), i32Ty, -1, rewriter)), + b.create(arith::CmpIPredicate::sle, fp8Exp, + createConst(op.getLoc(), i32Ty, 0, rewriter))); + // uint8_t shift_bits = (2 + e5m2_exponent); + // uint8_t e5m2_mantissa = (mantissa >> (24 - shift_bits)) & (0x3 >> (0 - + // e5m2_exponent)); return sign << 7 | 0x00 | e5m2_mantissa; + + Value shiftBits = b.create( + createConst(op.getLoc(), i32Ty, 2, rewriter), fp8Exp); + Value mantissaShift2 = b.create( + createConst(op.getLoc(), i32Ty, 24, rewriter), shiftBits); + Value e5m2_mantissa2 = b.create( + b.create(mantissa, mantissaShift2), + b.create( + createConst(op.getLoc(), i32Ty, 0x3, + rewriter), // 0b11 mask for 2 bits + b.create( + createConst(op.getLoc(), i32Ty, 0, rewriter), fp8Exp))); + Value subnormalResult2 = b.create( + 
b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create(createConst(op.getLoc(), i32Ty, 0x00, rewriter), + e5m2_mantissa2)); + + // if (e5m2_exponent < -1) + Value isZero = + b.create(arith::CmpIPredicate::slt, fp8Exp, + createConst(op.getLoc(), i32Ty, -1, rewriter)); + // return sign << 7 | 0x00; + Value zeroResult = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + createConst(op.getLoc(), i32Ty, 0x00, rewriter)); + + // Select the appropriate result based on the conditions + Value finalResult = b.create( + isSubnormal, subnormalResult, + b.create( + isSubnormal2, subnormalResult2, + b.create(isZero, zeroResult, resultInt))); + + // Truncate to i8 and bitcast to fp8e5m2 + Value resultI8 = b.create(i8Ty, finalResult); + Value result = b.create(resultTy, resultI8); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F8E4M3FNExtFOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const final { + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f8E4M3 → f32 for now + if (!llvm::isa(operandETy) || !resultETy.isF32()) { + return rewriter.notifyMatchFailure(op, "not a ext of f8E4M3 to f32."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + Type f32Ty = b.getF32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { + i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + f32Ty = shapedTy.clone(f32Ty); + } + + // Bitcast fp8 to raw uint8 + Value bits = b.create(i8Ty, operand); + // Zero-extend to 32 bits + Value bits32 = b.create(i32Ty, bits); + + // Extract sign + Value sign = b.create( 
+ bits32, createConst(op.getLoc(), i32Ty, 7, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // extract exponent + Value e4m3_exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 3, rewriter)); + e4m3_exponent = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 0xF, rewriter)); + + // extract mantissa + Value rounding_bias = createConst(op.getLoc(), i32Ty, 0x80000, rewriter); + Value mantissa = b.create(bits32, rounding_bias); + Value e4m3_mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 0x7, rewriter)); + + // bias exponent + Value exponent = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 7, rewriter)); + Value float_exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 127, rewriter)); + + // put everything together (normal number) e4m3_exponent > 0 + Value isNormal = + b.create(arith::CmpIPredicate::sgt, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter)); + + Value result = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + e4m3_mantissa, createConst(op.getLoc(), i32Ty, 20, rewriter)))); + + // sub-normal numbers handling (e4m3_mantissa >= 0x4) + Value isSubnormal1 = + b.create(arith::CmpIPredicate::sge, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x4, rewriter)); + + Value resultSubnormal1 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x3, rewriter), + b.create( + e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 21, rewriter))))); + + // else if e4m3_mantissa > 0x1 + Value isSubnormal2 = + b.create(arith::CmpIPredicate::sgt, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + Value resultSubormal2 = b.create( + b.create(sign, + 
createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + b.create( + float_exponent, + createConst(op.getLoc(), i32Ty, 1, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x1, rewriter), + b.create( + e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 22, rewriter))))); + + // else if e4m3_mantissa == 0x1 + Value isSubnormal3 = + b.create(arith::CmpIPredicate::eq, e4m3_mantissa, + createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + Value resultSubnormal3 = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 31, rewriter)), + b.create( + b.create( + float_exponent, createConst(op.getLoc(), i32Ty, 2, rewriter)), + createConst(op.getLoc(), i32Ty, 23, rewriter))); + + // else Zero + Value resultZero = b.create( + sign, createConst(op.getLoc(), i32Ty, 31, rewriter)); + + // Compute final result + result = b.create( + isNormal, result, + b.create( + isSubnormal1, resultSubnormal1, + b.create( + isSubnormal2, resultSubormal2, + b.create(isSubnormal3, resultSubnormal3, + resultZero)))); + + // Bitcast to f32 + result = b.create(f32Ty, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct F32ToF8E4M3FNTruncFOpConverter + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const final { + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + auto operand = op.getOperand(); + Type operandTy = operand.getType(); + Type resultTy = op.getType(); + Type operandETy = getElementTypeOrSelf(operandTy); + Type resultETy = getElementTypeOrSelf(resultTy); + + // Match only f32 → f8E4M3 + if (!operandETy.isF32() || !llvm::isa(resultETy)) { + return rewriter.notifyMatchFailure(op, "not a trunc of f32 to f8E4M3."); + } + + // Integer and float shaped types matching the input shape + Type i8Ty = b.getI8Type(); + Type i32Ty = b.getI32Type(); + if (auto shapedTy = dyn_cast(operandTy)) { 
+ i8Ty = shapedTy.clone(i8Ty); + i32Ty = shapedTy.clone(i32Ty); + } + + // Bitcast f32 to raw uint32 + Value bits32 = b.create(i32Ty, operand); + + // Constants + Value bias127 = createConst(op.getLoc(), i32Ty, 127, rewriter); + Value bias7 = createConst(op.getLoc(), i32Ty, 7, rewriter); + + // Extract sign + Value sign = b.create( + bits32, createConst(op.getLoc(), i32Ty, 31, rewriter)); + sign = b.create( + sign, createConst(op.getLoc(), i32Ty, 0x1, rewriter)); + + // Extract exponent + Value exponent = b.create( + bits32, createConst(op.getLoc(), i32Ty, 23, rewriter)); + exponent = b.create( + exponent, createConst(op.getLoc(), i32Ty, 0xFF, rewriter)); + exponent = b.create(exponent, bias127); + + // Extract the mantissa + Value mantissa = b.create( + bits32, createConst(op.getLoc(), i32Ty, 0x7FFFFF, rewriter)); + + // For normal numbers, add the implicit leading 1 in the mantissa + mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 0x800000, rewriter)); + + // Apply the bias for e4m3 (bias of 7) + Value e4m3_exponent = b.create(exponent, bias7); + + // if e4m3_exponent > 15 + Value isOverflow = + b.create(arith::CmpIPredicate::sgt, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 15, rewriter)); + + // Clamp to max finite value + Value maxFinite = + createConst(op.getLoc(), i32Ty, 0x7F, rewriter); // 0b01111111 in f8 + + // if ((e4m3_exponent > -3) && (e4m3_exponent <= 0)) + Value isSubnormal = + b.create(arith::CmpIPredicate::sge, e4m3_exponent, + createConst(op.getLoc(), i32Ty, -3, rewriter)); + isSubnormal = b.create( + isSubnormal, + b.create(arith::CmpIPredicate::sle, e4m3_exponent, + createConst(op.getLoc(), i32Ty, 0, rewriter))); + + Value shift_bits = b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 3, rewriter)); + Value e4m3_mantissa_subnormal = b.create( + mantissa, + b.create(createConst(op.getLoc(), i32Ty, 24, rewriter), + shift_bits)); + e4m3_mantissa_subnormal = b.create( + e4m3_mantissa_subnormal, + b.create( + 
createConst(op.getLoc(), i32Ty, 0x7, rewriter), + b.create( + createConst(op.getLoc(), i32Ty, 0, rewriter), e4m3_exponent))); + + Value resultSubnormal = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create( + createConst(op.getLoc(), i32Ty, 0x00, + rewriter), // Exponent is 0 for subnormals + e4m3_mantissa_subnormal)); + + // else if e4m3_exponent <= -3 + Value isZero = + b.create(arith::CmpIPredicate::sle, e4m3_exponent, + createConst(op.getLoc(), i32Ty, -3, rewriter)); + + Value resultZero = + createConst(op.getLoc(), i32Ty, 0x00, rewriter); // 0b00000000 + + // For normal numbers, normalize mantissa to fit into 3 bits (e4m3 has 3 + // bits for mantissa) + Value e4m3_mantissa = b.create( + mantissa, createConst(op.getLoc(), i32Ty, 20, rewriter)); + e4m3_mantissa = b.create( + e4m3_mantissa, createConst(op.getLoc(), i32Ty, 0x7, rewriter)); + + // Pack the sign, exponent, and mantissa into an 8-bit value (normal + // numbers) + Value result = b.create( + b.create(sign, + createConst(op.getLoc(), i32Ty, 7, rewriter)), + b.create( + b.create( + e4m3_exponent, createConst(op.getLoc(), i32Ty, 3, rewriter)), + e4m3_mantissa)); + + // compute final result (if no condition is met, result is normal) + result = b.create( + isOverflow, maxFinite, + b.create( + isSubnormal, resultSubnormal, + b.create(isZero, resultZero, result))); + + // Truncate to i8 and bitcast to f8e4m3 + result = b.create(i8Ty, result); + result = b.create(resultTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ArithExpandOpsPass : public arith::impl::ArithExpandOpsPassBase { using ArithExpandOpsPassBase::ArithExpandOpsPassBase; @@ -363,21 +910,42 @@ struct ArithExpandOpsPass if (includeBf16) { arith::populateExpandBFloat16Patterns(patterns); - target.addDynamicallyLegalOp( - [](arith::ExtFOp op) { - Type inETy = getElementTypeOrSelf(op.getOperand().getType()); - Type outETy = getElementTypeOrSelf(op.getType()); - return 
!(inETy.isBF16() && outETy.isF32()); - }); - - target.addDynamicallyLegalOp( - [](arith::TruncFOp op) { - Type inETy = getElementTypeOrSelf(op.getOperand().getType()); - Type outETy = getElementTypeOrSelf(op.getType()); - return !(inETy.isF32() && outETy.isBF16()); - }); + } + if (includeF8E5M2){ + arith::populateExpandF8E5M2Patterns(patterns); + } + if (includeF8E4M3FN){ + arith::populateExpandF8E4M3FNPatterns(patterns); } + target.addDynamicallyLegalOp( + [=](arith::ExtFOp op) { + Type inETy = getElementTypeOrSelf(op.getOperand().getType()); + Type outETy = getElementTypeOrSelf(op.getType()); + bool legalTypes = true; + if (includeBf16) + legalTypes &= !(inETy.isBF16() && outETy.isF32()); + if (includeF8E5M2) + legalTypes &= !inETy.isFloat8E5M2(); + if (includeF8E4M3FN) + legalTypes &= !inETy.isFloat8E4M3FN(); + return legalTypes; + }); + + target.addDynamicallyLegalOp( + [=](arith::TruncFOp op) { + Type inETy = getElementTypeOrSelf(op.getOperand().getType()); + Type outETy = getElementTypeOrSelf(op.getType()); + bool legalTypes = true; + if (includeBf16) + legalTypes &= !(inETy.isF32() && outETy.isBF16()); + if (includeF8E5M2) + legalTypes &= !outETy.isFloat8E5M2(); + if (includeF8E4M3FN) + legalTypes &= !outETy.isFloat8E4M3FN(); + return legalTypes; + }); + // clang-format on if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -399,6 +967,16 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) { patterns.getContext()); } +void mlir::arith::populateExpandF8E5M2Patterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + +void mlir::arith::populateExpandF8E4M3FNPatterns(RewritePatternSet &patterns) { + patterns.add( + patterns.getContext()); +} + void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); // clang-format off diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir 
b/mlir/test/Dialect/Arith/expand-ops.mlir index 174eb468cc00..b4d70abf0843 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e5m2=true include-f8e4m3fn=true" -split-input-file | FileCheck %s // Test ceil divide with signed integer // CHECK-LABEL: func @ceildivi @@ -310,3 +310,43 @@ func.func @minui(%a: i32, %b: i32) -> i32 { // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @extf_vector_f8E5M2_to_f32(%arg0 : vector<4xf8E5M2>) -> vector<4xf32> { + %0 = arith.extf %arg0 : vector<4xf8E5M2> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @extf_vector_f8E5M2_to_f32 +// CHECK-NOT: arith.extf + +// ----- + +func.func @truncf_vector_f32_to_f8E5M2(%arg0 : vector<4xf32>) -> vector<4xf8E5M2> { + %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E5M2> + return %0 : vector<4xf8E5M2> +} + +// CHECK-LABEL: @truncf_vector_f32_to_f8E5M2 +// CHECK-NOT: arith.truncf + +// ----- + +func.func @extf_vector_f8E4M3FN_to_f32(%arg0 : vector<4xf8E4M3FN>) -> vector<4xf32> { + %0 = arith.extf %arg0 : vector<4xf8E4M3FN> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @extf_vector_f8E4M3FN_to_f32 +// CHECK-NOT: arith.extf + +// ----- + +func.func @truncf_vector_f32_to_f8E4M3FN(%arg0 : vector<4xf32>) -> vector<4xf8E4M3FN> { + %0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E4M3FN> + return %0 : vector<4xf8E4M3FN> +} + +// CHECK-LABEL: @truncf_vector_f32_to_f8E4M3FN +// CHECK-NOT: arith.truncf -- Gitee