Commit d1af0bb7 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Support MLIR lowering of Relu (#3197)

* Support MLIR lowering of Relu

* Use EDSC comparison

* style-apply

* Use .inc file for Conversion classes list

* Disable i32 Relu for plaidml
parent a655e1cc
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp" #include "ngraph/type/element_type.hpp"
...@@ -306,11 +307,18 @@ namespace ngraph ...@@ -306,11 +307,18 @@ namespace ngraph
{ {
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
return compiler.create_binary_op<mlir::NGDotOp>(ng_node); return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
} }
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
{
return compiler.create_unary_op<mlir::NGReluOp>(ng_node);
}
} }
} }
} }
...@@ -320,6 +328,16 @@ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ ...@@ -320,6 +328,16 @@ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#include "ops_supported.inc" #include "ops_supported.inc"
}; };
template <typename UnaryOp>
mlir::Value* MLIRCompiler::create_unary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<UnaryOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v)
.getResult();
}
template <typename BinOp> template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node) mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{ {
......
...@@ -105,6 +105,9 @@ namespace ngraph ...@@ -105,6 +105,9 @@ namespace ngraph
"' in MLIR Compiler"); "' in MLIR Compiler");
} }
template <typename UnaryOp>
mlir::Value* create_unary_op(const ngraph::Node* ng_node);
template <typename BinOp> template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node); mlir::Value* create_binary_op(const ngraph::Node* ng_node);
......
...@@ -149,6 +149,7 @@ def NGSinhOp : NG_Unary_Arith_Op<"sinh">; ...@@ -149,6 +149,7 @@ def NGSinhOp : NG_Unary_Arith_Op<"sinh">;
def NGTanOp : NG_Unary_Arith_Op<"tan">; def NGTanOp : NG_Unary_Arith_Op<"tan">;
def NGTanhOp : NG_Unary_Arith_Op<"tanh">; def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">; def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;
def NGReluOp : NG_Unary_Arith_Op<"relu">;
// Binary Operations // Binary Operations
def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>; def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
......
...@@ -59,6 +59,21 @@ namespace ...@@ -59,6 +59,21 @@ namespace
DialectLoweringPass& m_pass; DialectLoweringPass& m_pass;
}; };
// Conversion classes declarations
#define MLIR_OP(OP) \
class OP##Conversion : public NGraphOpLowering \
{ \
public: \
explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
: NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
{ \
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
#include "op_lowerers.inc" #include "op_lowerers.inc"
// Helpers // Helpers
...@@ -147,11 +162,11 @@ namespace ...@@ -147,11 +162,11 @@ namespace
void DialectLoweringPass::populateNGraphToAffineConversionPatterns( void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
OwningRewritePatternList& patterns) OwningRewritePatternList& patterns)
{ {
RewriteListBuilder<NGAddOpConversion, #define MLIR_OP(OP) OP##Conversion,
NGArgMaxRedOpConversion, #define MLIR_LAST_OP(OP) OP##Conversion
NGArgMinRedOpConversion, RewriteListBuilder<
NGDotOpConversion, #include "op_lowerers.inc"
NGReturnOpConversion>::build(patterns, &getContext(), *this); >::build(patterns, &getContext(), *this);
} }
void DialectLoweringPass::findOutputValues() void DialectLoweringPass::findOutputValues()
...@@ -345,6 +360,7 @@ namespace ...@@ -345,6 +360,7 @@ namespace
// ADD // ADD
REWRITER(NGAddOp) REWRITER(NGAddOp)
{ {
auto add = cast<NGAddOp>(op); auto add = cast<NGAddOp>(op);
auto loc = add.getLoc(); auto loc = add.getLoc();
...@@ -395,6 +411,61 @@ namespace ...@@ -395,6 +411,61 @@ namespace
return matchSuccess(); return matchSuccess();
} }
// Relu
REWRITER(NGReluOp)
{
auto loc = cast<NGReluOp>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
// get new operands
Value* lhs = operands[0];
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
NGRAPH_CHECK(!elemTy.isa<FloatType>(),
"NGReluOp with float element type should not be lowered until MLIR supports "
"lowering !std.CmpF");
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs);
if (auto floatTy = elemTy.dyn_cast<FloatType>())
{
ValueHandle zero = intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
iRes(ivs) = intrinsics::select(val > zero, val, zero);
}
else if (auto intTy = elemTy.dyn_cast<IntegerType>())
{
ValueHandle zero = intrinsics::constant_int(0, intTy.getWidth());
iRes(ivs) = intrinsics::select(val > zero, val, zero);
}
else
{
NGRAPH_CHECK(false, "Unsupported type for Relu");
}
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGDotOp) REWRITER(NGDotOp)
{ {
auto dot = cast<NGDotOp>(op); auto dot = cast<NGDotOp>(op);
......
...@@ -14,27 +14,21 @@ ...@@ -14,27 +14,21 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
// Add new dialect ops with lowering support to this file
#ifndef MLIR_OP
#define MLIR_OP
#endif
// Add new dialect ops lowerers to this file #ifndef MLIR_LAST_OP
#define MLIR_LAST_OP(OP) MLIR_OP(OP)
#endif
#define DECL_OP_CONV(OP) \ MLIR_OP(NGAddOp)
class OP##Conversion : public NGraphOpLowering \ MLIR_OP(NGArgMaxRedOp)
{ \ MLIR_OP(NGArgMinRedOp)
public: \ MLIR_OP(NGDotOp)
explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \ MLIR_OP(NGReluOp)
: NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \ MLIR_LAST_OP(NGReturnOp)
{ \
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
DECL_OP_CONV(NGAddOp) #undef MLIR_OP
DECL_OP_CONV(NGArgMaxRedOp) #undef MLIR_LAST_OP
DECL_OP_CONV(NGArgMinRedOp)
DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp)
#undef DECL_OP_CONV
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
MLIR_OP(Add) MLIR_OP(Add)
MLIR_OP(ArgMin) MLIR_OP(ArgMin)
MLIR_OP(ArgMax) MLIR_OP(ArgMax)
MLIR_OP(Relu)
MLIR_OP(Dot) MLIR_OP(Dot)
// Add new supported ops here // Add new supported ops here
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
//***************************************************************************** //*****************************************************************************
#include "mlir_subgraph_extraction.hpp" #include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -24,6 +23,7 @@ ...@@ -24,6 +23,7 @@
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/relu.hpp"
using namespace ngraph::descriptor; using namespace ngraph::descriptor;
using namespace ngraph::op; using namespace ngraph::op;
...@@ -115,6 +115,15 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node ...@@ -115,6 +115,15 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (!node->input(0).get_element_type().is_integral()) if (!node->input(0).get_element_type().is_integral())
return false; return false;
} }
// Relu is supported for integer types only until MLIR adds support for lowering !std.CmpF to LLVM dialect
if (TI(ngraph::op::Relu) == TI(*node))
{
if (!node->get_element_type().is_integral())
{
return false;
}
}
return true; return true;
} }
......
...@@ -171,6 +171,7 @@ prelu ...@@ -171,6 +171,7 @@ prelu
hardsigmoid hardsigmoid
prelu_shared_slope prelu_shared_slope
prelu_negative_slope prelu_negative_slope
relu_2Dfprop_i32
conv_bias_1d conv_bias_1d
conv_bias_2d conv_bias_2d
conv_bias_3d conv_bias_3d
......
...@@ -5286,6 +5286,26 @@ NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop) ...@@ -5286,6 +5286,26 @@ NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop)
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS)); EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
} }
NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop_i32)
{
auto shape_a = Shape{2, 5};
auto A = make_shared<op::Parameter>(element::i32, shape_a);
auto relu = make_shared<op::Relu>(A);
auto shape_rt = Shape{2, 5};
auto f = make_shared<Function>(relu, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::i32, shape_a);
copy_data(a, vector<int32_t>{1, 8, -8, 17, -2, 1, 8, -8, 17, -1});
auto result = backend->create_tensor(element::i32, shape_rt);
vector<int32_t> expected{1, 8, 0, 17, 0, 1, 8, 0, 17, 0};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ(expected, read_vector<int32_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, relu_4Dfprop) NGRAPH_TEST(${BACKEND_NAME}, relu_4Dfprop)
{ {
auto shape_a = Shape{2, 2, 2, 2}; auto shape_a = Shape{2, 2, 2, 2};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment