Commit 9a01762a authored by Pruthvi's avatar Pruthvi Committed by Robert Kimball

[MLIR] Affine lowering support for Comparison Ops (#3872)

* Affine lowering support for
    1) GreaterEq
    2) LessEq
    3) Equal
    4) NotEqual

* - cast result op from i1 to i8 for comparision operators

* Addressed PR comments

* Style fix

* - style check
- use select instead of zero_extendi durinng CompOp lowering

* - fix style
- use createOneConstant and createZeroConstant helpers in select intrinsic

* Use NG_U8_TYPE_ID for BooleanType in ngraph dialect

* Diable CE and Softmax unit test in MLIR

* - LIT parser test for comparision ops

* - Affine dailect LIT tests for Comparision Ops

* Address PR feedback

* fix typo

* - use `cast` to deduce element Type
- add more strict type checking to LIT Test

* fix CHECK label's for comparision ops

* Use UInt8 in verification logic for CMP op's

* - use UInt8 for the resultOp verification Logic in CMP op
- fix unit test failures
parent 7bb94ca0
......@@ -163,6 +163,7 @@ namespace
DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type);
ValueHandle createOneConstant(mlir::Type type);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
......@@ -538,6 +539,30 @@ namespace
return matchSuccess();
}
REWRITER(NGGreaterEqOp)
{
lowerBinaryElementwise<mlir::NGGreaterEqOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGLessEqOp)
{
lowerBinaryElementwise<mlir::NGLessEqOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGEqOp)
{
lowerBinaryElementwise<mlir::NGEqOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGNotEqOp)
{
lowerBinaryElementwise<mlir::NGNotEqOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxOp)
{
lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
......@@ -1248,6 +1273,8 @@ namespace
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
// Steps
auto steps = vLHS.getSteps();
// element type of the operand
Type elemTy = result->getType().cast<MemRefType>().getElementType();
AffineLoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
......@@ -1267,13 +1294,52 @@ namespace
{
iRes(ivs) = iLHS(ivs) / iRHS(ivs);
}
// TODO(pthoreho) For all comparision operators, use
// edsc::intrinsics::zero_extendi(ValueHandle(iLHS(ivs)) !=
// ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
// instead of edsc::intrinsics::select once `zero_extendi` is
// made available in the edsc::intrinsics namescope in MLIR repo.
else if (isa<NGGreaterOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs));
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGLessOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs));
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGGreaterEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGLessEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGNotEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGMaxOp>(op))
{
......@@ -1413,6 +1479,30 @@ namespace
}
NGRAPH_UNREACHABLE("Unsupported type");
}
ValueHandle createOneConstant(mlir::Type type)
{
if (auto floatTy = type.dyn_cast<FloatType>())
{
if (floatTy.isF32())
{
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
}
else if (floatTy.isF64())
{
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
}
else
{
NGRAPH_UNREACHABLE("Unsupported floating-point precision");
}
}
else if (auto intTy = type.dyn_cast<IntegerType>())
{
return intrinsics::constant_int(1, intTy.getWidth());
}
NGRAPH_UNREACHABLE("Unsupported type");
}
} // namespace
namespace mlir
......
......@@ -34,6 +34,10 @@ MLIR_OP(NGDotOp , false )
MLIR_OP(NGGatherOp , false )
MLIR_OP(NGGreaterOp , true )
MLIR_OP(NGLessOp , true )
MLIR_OP(NGGreaterEqOp , true )
MLIR_OP(NGLessEqOp , true )
MLIR_OP(NGEqOp , true )
MLIR_OP(NGNotEqOp , true )
MLIR_OP(NGMulOp , true )
MLIR_OP(NGMaxOp , true )
MLIR_OP(NGMinOp , true )
......
......@@ -171,8 +171,11 @@ static mlir::LogicalResult verifyCmpOp(T* op)
NGTensorType resType = r0.cast<NGTensorType>();
// result of same shape as input and has bool type
if (!resType.isCompatibleShape(opType0) || !resType.getElementType().isa<NGBoolType>())
if (!resType.isCompatibleShape(opType0) ||
!resType.getElementType().cast<NGIntegerType>().isUInt8())
{
return op->emitOpError("Incompatible result shape or type for comparison op");
}
return mlir::success();
}
......
......@@ -13,6 +13,10 @@ MLIR_OP(Convolution)
MLIR_OP(Gather)
MLIR_OP(Greater)
MLIR_OP(Less)
MLIR_OP(GreaterEq)
MLIR_OP(LessEq)
MLIR_OP(Equal)
MLIR_OP(NotEqual)
MLIR_OP(Maximum)
MLIR_OP(Minimum)
MLIR_OP(Multiply)
......
......@@ -28,15 +28,19 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
......@@ -473,18 +477,6 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
// Disable any op using boolean type until we have support for i1<->i8 conversion in MLIR.
// Otherwise, we would generate code like this:
// %0 = icmp %a, %b : i1
// store %0, %c[%arg1] : i8 // Type error: trying to store an i1 into an i8.
for (auto& output : node->get_outputs())
{
if (output.get_element_type() == element::boolean)
{
return false;
}
}
if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
{
return true;
......
......@@ -33,14 +33,18 @@
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/index_reduction.hpp"
......@@ -358,6 +362,29 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Less)
return NgDialectObj.createGenericOp<mlir::NGLessOp>(ngNode);
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::GreaterEq)
{
return NgDialectObj.createGenericOp<mlir::NGGreaterEqOp>(ngNode);
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::LessEq)
{
return NgDialectObj.createGenericOp<mlir::NGLessEqOp>(ngNode);
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Equal)
{
return NgDialectObj.createGenericOp<mlir::NGEqOp>(ngNode);
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::NotEqual)
{
return NgDialectObj.createGenericOp<mlir::NGNotEqOp>(ngNode);
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Maximum)
{
......
......@@ -830,6 +830,11 @@ TEST(core_fusion, softmax_crossentropy_bprop_with_ignore_mask)
}
#endif
// TODO(pthoreho): MLIR currently does not support all the op's needed for CrossEntropy+Softmax
// this results in multiple CompiledKernels and we cannot able to safely check for certain op's
// from the function created by user.
// Note: remove this guards once we have full support for CE and Softmax through MLIR
void test_softmax_crossentropy(Shape input_shape,
Shape label_shape,
bool soft_label,
......@@ -864,7 +869,7 @@ void test_softmax_crossentropy(Shape input_shape,
}
}
TEST(core_fusion, softmax_crossentropy)
TEST(core_fusion, MLIR_DISABLE_TEST(softmax_crossentropy))
{
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1);
test_softmax_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5);
......@@ -901,7 +906,7 @@ void test_crossentropy(Shape input_shape, Shape label_shape, bool soft_label, in
}
}
TEST(core_fusion, crossentropy)
TEST(core_fusion, MLIR_DISABLE_TEST(crossentropy))
{
test_crossentropy(Shape{41, 37}, Shape{41, 37}, true, -1);
test_crossentropy(Shape{41, 37}, Shape{41, 1}, false, 5);
......
......@@ -5,12 +5,13 @@
// -----
// Gather Op
// CHECK: affine.for %[[I:.*]] = 0 to 16 {
// CHECK: %[[L0:.*]] = affine.load %{{.*}}[%[[I]]]
// CHECK: %[[GATHER_IDX:.*]] = index_cast %[[L0]]
// CHECK: affine.for %[[J:.*]] = 0 to 32 {
// CHECK: %[[VALUE:.*]] = load %{{.*}}[%[[GATHER_IDX]], %[[J]]]
// CHECK: affine.store %[[VALUE]], {{.*}}[%[[I]], %[[J]]]
// CHECK-LABEL: func @simple_gather
// CHECK: affine.for %[[I:.*]] = 0 to 16 {
// CHECK: %[[L0:.*]] = affine.load %{{.*}}[%[[I]]]
// CHECK: %[[GATHER_IDX:.*]] = index_cast %[[L0]]
// CHECK: affine.for %[[J:.*]] = 0 to 32 {
// CHECK: %[[VALUE:.*]] = load %{{.*}}[%[[GATHER_IDX]], %[[J]]]
// CHECK: affine.store %[[VALUE]], {{.*}}[%[[I]], %[[J]]]
func @simple_gather(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>) -> !ng.tensor<16x32xf32> {
%0 = "ng.gather"(%arg1, %arg0) {axis = 0 : i64} : (!ng.tensor<512x32xf32>, !ng.tensor<16x!ng.i64>) -> !ng.tensor<16x32xf32>
"ng.return"(%0) : (!ng.tensor<16x32xf32>) -> ()
......@@ -18,7 +19,116 @@ func @simple_gather(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>
// -----
// Equal Op
// CHECK-LABEL: func @simple_equal
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "oeq", %[[O2]], %[[O1]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_equal(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.equal"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// NotEqual Op
// CHECK-LABEL: func @simple_notequal
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "one", %[[O2]], %[[O1]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_notequal(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.not.equal"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// Greater Op
// CHECK-LABEL: func @simple_greater
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "ogt", %[[O2:.*]], %[[O1:.*]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_greater(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.greater"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// GreaterEq Op
// CHECK-LABEL: func @simple_greatereq
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "oge", %[[O2]], %[[O1]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_greatereq(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.greater.eq"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// Less Op
// CHECK-LABEL: func @simple_less
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "olt", %[[O2]], %[[O1]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_less(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.less"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// LessEq Op
// CHECK-LABEL: func @simple_lesseq
// CHECK: affine.for %[[I:.*]] = 0 to 2
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 2
// CHECK-NEXT: %[[C1:.*]] = constant 0 : i8
// CHECK-NEXT: %[[C2:.*]] = constant 1 : i8
// CHECK: %[[O1:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK: %[[O2:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = cmpf "ole", %[[O2]], %[[O1]] : f32
// CHECK: %[[R2:.*]] = select %[[R1]], %[[C2]], %[[C1]] : i8
// CHECK-NEXT: affine.store %[[R2]], %{{.*}}[%[[I]], %[[J]]]
func @simple_lesseq(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>{
%0 = "ng.less.eq"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// -----
// Dot Op
// CHECK-LABEL: func @simple_dot
// CHECK: affine.for %[[I:.*]] = 0 to 16
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 32
// CHECK-NEXT: affine.store %{{.*}}, %[[RESULT:.*]][%[[I]], %[[J]]]
......
......@@ -6,8 +6,49 @@
// CHECK-LABEL: func @add_float
func @add_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2xf32> {
// CHECK: %{{[0-9]+}} = "ng.add"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2xf32>
// CHECK: %{{.*}} = "ng.add"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2xf32>
%0 = "ng.add"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
}
// CHECK-LABEL: func @equal_float
func @equal_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.equal"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.equal"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// CHECK-LABEL: func @notequal_float
func @notequal_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.not.equal"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.not.equal"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// CHECK-LABEL: func @greater_float
func @greater_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.greater"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.greater"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// CHECK-LABEL: func @greatereq_float
func @greatereq_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.greater.eq"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.greater.eq"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// CHECK-LABEL: func @less_float
func @less_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.less"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.less"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
// CHECK-LABEL: func @lesseq_float
func @lesseq_float(%arg0: !ng.tensor<2x2xf32>, %arg1: !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8> {
// CHECK: %{{.*}} = "ng.less.eq"(%{{.*}}, %{{.*}}) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
%0 = "ng.less.eq"(%arg1, %arg0) : (!ng.tensor<2x2xf32>, !ng.tensor<2x2xf32>) -> !ng.tensor<2x2x!ng.u8>
"ng.return"(%0) : (!ng.tensor<2x2x!ng.u8>) -> ()
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment