Skip to content

Commit 3fe55f2

Browse files
committed
[TorchToArith] Implement conversion patterns for AtenNegFloatOp and add ConvertAtenNegIntOp test case.
Implement conversion patterns for `AtenNegFloatOp`: arith::subf(0.0, a); Add `ConvertAtenNegIntOp` test case.
1 parent 42c3c29 commit 3fe55f2

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,27 @@ class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
110110
};
111111
} // namespace
112112

113+
namespace {
114+
class ConvertAtenNegFloatOp : public OpConversionPattern<AtenNegFloatOp> {
115+
public:
116+
using OpConversionPattern<AtenNegFloatOp>::OpConversionPattern;
117+
LogicalResult matchAndRewrite(
118+
AtenNegFloatOp op,
119+
typename OpConversionPattern<AtenNegFloatOp>::OpAdaptor adaptor,
120+
ConversionPatternRewriter &rewriter) const override {
121+
Value a = adaptor.getA();
122+
Type inputDtype = a.getType();
123+
rewriter.replaceOpWithNewOp<arith::SubFOp>(
124+
op,
125+
arith::ConstantOp::create(
126+
rewriter, op.getLoc(),
127+
rewriter.getFloatAttr(inputDtype, /*value=*/0.0)),
128+
a);
129+
return success();
130+
}
131+
};
132+
} // namespace
133+
113134
namespace {
114135
template <typename AtenOp, typename UnaryOp>
115136
class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
@@ -501,6 +522,9 @@ class ConvertTorchToArith
501522
patterns.add<ConvertAtenAddOp>(typeConverter, context);
502523
target.addIllegalOp<AtenNegIntOp>();
503524
patterns.add<ConvertAtenNegIntOp>(typeConverter, context);
525+
target.addIllegalOp<AtenNegFloatOp>();
526+
patterns.add<ConvertAtenNegFloatOp>(typeConverter, context);
527+
504528
target.addIllegalOp<AtenAddIntOp, AtenAddFloatIntOp, AtenSubIntOp,
505529
AtenMulIntOp, AtenRemainderIntOp, AtenMulIntFloatOp,
506530
AtenMulFloatIntOp>();

test/Conversion/TorchToArith/basic.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,27 @@ func.func @torch.aten.Int.Scalar(%arg0: !torch.float) -> !torch.int {
407407
%0 = torch.aten.Int.Scalar %arg0 : !torch.float -> !torch.int
408408
return %0 : !torch.int
409409
}
410+
411+
// CHECK-LABEL: func.func @torch.aten.neg.int(
412+
// CHECK-SAME: %[[ARG:.*]]: !torch.int) -> !torch.int {
413+
// CHECK: %[[ARG_I64:.*]] = torch_c.to_i64 %[[ARG]]
414+
// CHECK: %[[CST:.*]] = arith.constant 0 : i64
415+
// CHECK: %[[SUB:.*]] = arith.subi %[[CST:.*]], [[ARG_I64:.*]] : i64
416+
// CHECK: %[[OUT:.*]] = torch_c.from_i64 %[[SUB:.*]]
417+
// CHECK: return %[[OUT:.*]] : !torch.int
418+
func.func @torch.aten.neg.int(%arg0: !torch.int) -> !torch.int {
419+
%0 = torch.aten.neg.int %arg0 : !torch.int -> !torch.int
420+
return %0 : !torch.int
421+
}
422+
423+
// CHECK-LABEL: func.func @torch.aten.neg.float(
424+
// CHECK-SAME: %[[ARG:.*]]: !torch.float) -> !torch.float {
425+
// CHECK: %[[ARG_F64:.*]] = torch_c.to_f64 %[[ARG]]
426+
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f64
427+
// CHECK: %[[SUB:.*]] = arith.subf %[[CST:.*]], [[ARG_F64:.*]] : f64
428+
// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[SUB:.*]]
429+
// CHECK: return %[[OUT:.*]] : !torch.float
430+
func.func @torch.aten.neg.float(%arg0: !torch.float) -> !torch.float {
431+
%0 = torch.aten.neg.float %arg0 : !torch.float -> !torch.float
432+
return %0 : !torch.float
433+
}

0 commit comments

Comments
 (0)