Commit d1af0bb7 authored by Nagy Mostafa, committed by Scott Cyphers

[MLIR] Support MLIR lowering of Relu (#3197)

* Support MLIR lowering of Relu

* Use EDSC comparison

* style-apply

* Use .inc file for Conversion classes list

* Disable i32 Relu for plaidml
parent a655e1cc
@@ -28,6 +28,7 @@
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
@@ -306,11 +307,18 @@ namespace ngraph
{
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
{
return compiler.create_unary_op<mlir::NGReluOp>(ng_node);
}
}
}
}
@@ -320,6 +328,16 @@ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#include "ops_supported.inc"
};
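The COMPILE_OP_DECL macro and the MLIR_OP binding that populates op_dispatcher are defined outside the lines shown here, so the following self-contained sketch only approximates their shape: it dispatches on the node's dynamic type through a type_index-keyed map, with MiniCompiler, compile, and dispatcher as illustrative stand-ins for the real MLIRCompiler API.

#include <iostream>
#include <map>
#include <typeindex>

namespace op
{
    struct Node { virtual ~Node() = default; };
    struct Relu : Node {};
    struct Dot : Node {};
}

#define TI(x) std::type_index(typeid(x))

struct MiniCompiler
{
    // One instantiation per supported op, mirroring the COMPILE_OP_DECL
    // specializations above.
    template <typename Op>
    static void compile(const op::Node*)
    {
        std::cout << "compiling " << typeid(Op).name() << "\n";
    }

    using CompileFn = void (*)(const op::Node*);
    static const std::map<std::type_index, CompileFn> dispatcher;
};

// ops_supported.inc would supply this MLIR_OP(...) list (see the hunk
// near the end of this commit).
#define MLIR_OP(OP) {TI(op::OP), &MiniCompiler::compile<op::OP>},
const std::map<std::type_index, MiniCompiler::CompileFn> MiniCompiler::dispatcher{
    MLIR_OP(Relu)
    MLIR_OP(Dot)
};
#undef MLIR_OP

int main()
{
    op::Relu relu;
    // Look up the compile hook by the node's dynamic type.
    MiniCompiler::dispatcher.at(TI(relu))(&relu);
}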
template <typename UnaryOp>
mlir::Value* MLIRCompiler::create_unary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<UnaryOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v)
.getResult();
}
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
......
@@ -105,6 +105,9 @@ namespace ngraph
"' in MLIR Compiler");
}
template <typename UnaryOp>
mlir::Value* create_unary_op(const ngraph::Node* ng_node);
template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
......
@@ -149,6 +149,7 @@ def NGSinhOp : NG_Unary_Arith_Op<"sinh">;
def NGTanOp : NG_Unary_Arith_Op<"tan">;
def NGTanhOp : NG_Unary_Arith_Op<"tanh">;
def NGSqrtOp : NG_Unary_Arith_Op<"sqrt">;
def NGReluOp : NG_Unary_Arith_Op<"relu">;
// Binary Operations
def NGAddOp : NG_Binary_Arith_Op<"add", [Commutative]>;
......
@@ -59,6 +59,21 @@ namespace
DialectLoweringPass& m_pass;
};
// Conversion classes declarations
#define MLIR_OP(OP) \
class OP##Conversion : public NGraphOpLowering \
{ \
public: \
explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
: NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
{ \
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
#include "op_lowerers.inc"
// Helpers
@@ -147,11 +162,11 @@ namespace
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
OwningRewritePatternList& patterns)
{
RewriteListBuilder<NGAddOpConversion,
NGArgMaxRedOpConversion,
NGArgMinRedOpConversion,
NGDotOpConversion,
NGReturnOpConversion>::build(patterns, &getContext(), *this);
#define MLIR_OP(OP) OP##Conversion,
#define MLIR_LAST_OP(OP) OP##Conversion
RewriteListBuilder<
#include "op_lowerers.inc"
>::build(patterns, &getContext(), *this);
}
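Given the op list in op_lowerers.inc after this commit (NGAddOp, NGArgMaxRedOp, NGArgMinRedOp, NGDotOp, NGReluOp, NGReturnOp), those two macro bindings make the #include expand to the same call the deleted code spelled out by hand, now with the Relu conversion included:

// Preprocessed form of the #include above (illustrative):
RewriteListBuilder<NGAddOpConversion,
                   NGArgMaxRedOpConversion,
                   NGArgMinRedOpConversion,
                   NGDotOpConversion,
                   NGReluOpConversion,
                   NGReturnOpConversion>::build(patterns, &getContext(), *this);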
void DialectLoweringPass::findOutputValues()
@@ -345,6 +360,7 @@ namespace
// ADD
REWRITER(NGAddOp)
{
auto add = cast<NGAddOp>(op);
auto loc = add.getLoc();
@@ -395,6 +411,61 @@ namespace
return matchSuccess();
}
// Relu
REWRITER(NGReluOp)
{
auto loc = cast<NGReluOp>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that the builder's current function is still the original function body;
// use getBlock to get the new block instead.
// Get the new operands.
Value* lhs = operands[0];
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
NGRAPH_CHECK(lhs->getType().isa<MemRefType>());
Type elemTy = lhs->getType().dyn_cast<MemRefType>().getElementType();
NGRAPH_CHECK(!elemTy.isa<FloatType>(),
"NGReluOp with float element type should not be lowered until MLIR supports "
"lowering !std.CmpF");
LoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs);
if (auto floatTy = elemTy.dyn_cast<FloatType>())
{
ValueHandle zero = intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
iRes(ivs) = intrinsics::select(val > zero, val, zero);
}
else if (auto intTy = elemTy.dyn_cast<IntegerType>())
{
ValueHandle zero = intrinsics::constant_int(0, intTy.getWidth());
iRes(ivs) = intrinsics::select(val > zero, val, zero);
}
else
{
NGRAPH_CHECK(false, "Unsupported type for Relu");
}
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
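For reference, the loop nest built above is an elementwise max-with-zero. The integer path it emits is equivalent to this plain-C++ sketch (illustrative semantics only, not the generated IR):

#include <cstdint>
#include <vector>

// Scalar semantics of the affine loop nest generated for NGReluOp on i32.
std::vector<int32_t> relu_i32(const std::vector<int32_t>& lhs)
{
    std::vector<int32_t> res(lhs.size());
    for (std::size_t i = 0; i < lhs.size(); ++i)
    {
        // iRes(ivs) = select(val > zero, val, zero)
        res[i] = lhs[i] > 0 ? lhs[i] : 0;
    }
    return res;
}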
REWRITER(NGDotOp)
{
auto dot = cast<NGDotOp>(op);
......
@@ -14,27 +14,21 @@
// limitations under the License.
//*****************************************************************************
// Add new dialect ops with lowering support to this file
#ifndef MLIR_OP
#define MLIR_OP
#endif
// Add new dialect ops lowerers to this file
#ifndef MLIR_LAST_OP
#define MLIR_LAST_OP(OP) MLIR_OP(OP)
#endif
#define DECL_OP_CONV(OP) \
class OP##Conversion : public NGraphOpLowering \
{ \
public: \
explicit OP##Conversion(mlir::MLIRContext* context, DialectLoweringPass& pass) \
: NGraphOpLowering(mlir::OP::getOperationName(), context, pass) \
{ \
} \
\
PatternMatchResult matchAndRewrite(Operation* op, \
ArrayRef<Value*> operands, \
PatternRewriter& rewriter) const override; \
};
MLIR_OP(NGAddOp)
MLIR_OP(NGArgMaxRedOp)
MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGDotOp)
MLIR_OP(NGReluOp)
MLIR_LAST_OP(NGReturnOp)
DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGArgMaxRedOp)
DECL_OP_CONV(NGArgMinRedOp)
DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp)
#undef DECL_OP_CONV
#undef MLIR_OP
#undef MLIR_LAST_OP
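The #ifndef defaults and trailing #undefs are what let the same .inc be included twice with different macro bindings: once under the class-declaring MLIR_OP in the lowering pass, and once under the list-building MLIR_OP/MLIR_LAST_OP pair in populateNGraphToAffineConversionPatterns. A self-contained sketch of that protocol, with OPS_LIST standing in for the #include:

#include <iostream>
#include <typeinfo>

// OPS_LIST plays the role of '#include "op_lowerers.inc"'.
#define OPS_LIST \
    MLIR_OP(NGAddOp) \
    MLIR_LAST_OP(NGReluOp)

// First expansion: declare one conversion class per op.
#define MLIR_OP(OP) struct OP##Conversion {};
#define MLIR_LAST_OP(OP) MLIR_OP(OP)
OPS_LIST
#undef MLIR_OP
#undef MLIR_LAST_OP

// Second expansion: splice the same ops into a template argument list;
// the LAST variant omits the trailing comma.
template <typename... Ts>
struct RewriteList {};

#define MLIR_OP(OP) OP##Conversion,
#define MLIR_LAST_OP(OP) OP##Conversion
using Patterns = RewriteList<OPS_LIST>;
#undef MLIR_OP
#undef MLIR_LAST_OP

int main()
{
    // Prints an implementation-defined spelling of
    // RewriteList<NGAddOpConversion, NGReluOpConversion>.
    std::cout << typeid(Patterns).name() << "\n";
}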
@@ -6,6 +6,7 @@
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(ArgMax)
MLIR_OP(Relu)
MLIR_OP(Dot)
// Add new supported ops here
......
@@ -15,7 +15,6 @@
//*****************************************************************************
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
@@ -24,6 +23,7 @@
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/relu.hpp"
using namespace ngraph::descriptor;
using namespace ngraph::op;
@@ -115,6 +115,15 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (!node->input(0).get_element_type().is_integral())
return false;
}
// Relu is supported for integer types only until MLIR adds support for lowering !std.CmpF to the LLVM dialect
if (TI(ngraph::op::Relu) == TI(*node))
{
if (!node->get_element_type().is_integral())
{
return false;
}
}
return true;
}
......
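The TI(...) comparison above keys on the node's dynamic type. Its definition is not part of this diff; a presumed shape, shown only to make the comparison semantics concrete, is:

#include <typeindex>

// Presumed TI helper: a std::type_index over the operand's type, so
// TI(ngraph::op::Relu) == TI(*node) holds exactly when *node is a Relu.
#define TI(x) std::type_index(typeid(x))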
@@ -171,6 +171,7 @@ prelu
hardsigmoid
prelu_shared_slope
prelu_negative_slope
relu_2Dfprop_i32
conv_bias_1d
conv_bias_2d
conv_bias_3d
......
@@ -5286,6 +5286,26 @@ NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop)
EXPECT_TRUE(test::all_close_f(read_vector<float>(result), expected, MIN_FLOAT_TOLERANCE_BITS));
}
NGRAPH_TEST(${BACKEND_NAME}, relu_2Dfprop_i32)
{
auto shape_a = Shape{2, 5};
auto A = make_shared<op::Parameter>(element::i32, shape_a);
auto relu = make_shared<op::Relu>(A);
auto shape_rt = Shape{2, 5};
auto f = make_shared<Function>(relu, ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto a = backend->create_tensor(element::i32, shape_a);
copy_data(a, vector<int32_t>{1, 8, -8, 17, -2, 1, 8, -8, 17, -1});
auto result = backend->create_tensor(element::i32, shape_rt);
vector<int32_t> expected{1, 8, 0, 17, 0, 1, 8, 0, 17, 0};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ(expected, read_vector<int32_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, relu_4Dfprop)
{
auto shape_a = Shape{2, 2, 2, 2};
......