Commit 073db8fb authored by Pruthvi's avatar Pruthvi Committed by Scott Cyphers

- MLIR Binary ElementWise Op (#3223)

* - templatize computing binary elementwise
- added lowering support for Add, Sub, Multiply, Divide

* - Added Support for Greater and less Op

* -Add support for Minimum and Maximum

* use edsc::intrinsics::select instead of terenary operator

* Addressed PR comments

* - return after the conditional check
parent 5f0391d3
......@@ -26,9 +26,16 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
......@@ -296,6 +303,48 @@ namespace ngraph
return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{
return compiler.create_binary_op<mlir::NGSubOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{
return compiler.create_binary_op<mlir::NGMulOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{
return compiler.create_binary_op<mlir::NGDivOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{
return compiler.create_binary_op<mlir::NGGreaterOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{
return compiler.create_binary_op<mlir::NGLessOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{
return compiler.create_binary_op<mlir::NGMaxOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{
return compiler.create_binary_op<mlir::NGMinOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
......
......@@ -83,6 +83,12 @@ namespace
PatternRewriter& rewriter,
DialectLoweringPass& m_pass);
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
{
......@@ -363,42 +369,50 @@ namespace
// ADD
REWRITER(NGAddOp)
{
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGSubOp)
{
auto add = cast<NGAddOp>(op);
auto loc = add.getLoc();
lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// Note that builder's current function is still the original function body.
// use getBlock to get the new block instead.
REWRITER(NGMulOp)
{
lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
// get new operands
Value* lhs = operands[0];
Value* rhs = operands[1];
REWRITER(NGDivOp)
{
lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
// clang-format off
LoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
iRes(ivs) = iLHS(ivs) + iRHS(ivs);
});
// clang-format on
rewriter.replaceOp(op, {result});
REWRITER(NGGreaterOp)
{
lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGLessOp)
{
lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGMaxOp)
{
lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGMinOp)
{
lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
......@@ -541,6 +555,81 @@ namespace
#undef REWRITER
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// get new operands
Value* lhs = operands[0];
Value* rhs = operands[1];
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
LoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
if (isa<NGAddOp>(op))
{
iRes(ivs) = iLHS(ivs) + iRHS(ivs);
}
else if (isa<NGSubOp>(op))
{
iRes(ivs) = iLHS(ivs) - iRHS(ivs);
}
else if (isa<NGMulOp>(op))
{
iRes(ivs) = iLHS(ivs) * iRHS(ivs);
}
else if (isa<NGDivOp>(op))
{
iRes(ivs) = iLHS(ivs) / iRHS(ivs);
}
else if (isa<NGGreaterOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs));
}
else if (isa<NGLessOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs));
}
else if (isa<NGMaxOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else if (isa<NGMinOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else
{
NGRAPH_CHECK(false, "Unsupported op");
}
});
rewriter.replaceOp(op, {result});
}
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
......
......@@ -26,8 +26,15 @@
MLIR_OP(NGAddOp)
MLIR_OP(NGArgMaxRedOp)
MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGDivOp)
MLIR_OP(NGDotOp)
MLIR_OP(NGGreaterOp)
MLIR_OP(NGLessOp)
MLIR_OP(NGMulOp)
MLIR_OP(NGMaxOp)
MLIR_OP(NGMinOp)
MLIR_OP(NGReluOp)
MLIR_OP(NGSubOp)
MLIR_LAST_OP(NGReturnOp)
#undef MLIR_OP
......
......@@ -6,8 +6,15 @@
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(ArgMax)
MLIR_OP(Relu)
MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Greater)
MLIR_OP(Less)
MLIR_OP(Maximum)
MLIR_OP(Minimum)
MLIR_OP(Multiply)
MLIR_OP(Subtract)
MLIR_OP(Relu)
// Add new supported ops here
#undef MLIR_OP
......@@ -21,10 +21,17 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
using namespace ngraph::descriptor;
using namespace ngraph::op;
......@@ -275,13 +282,49 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
{
return false;
}
else
{
return true;
}
}
if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
{
// TODO: Remove this when MLIR has float point cmp support
if (!node->input(0).get_element_type().is_integral())
{
return false;
}
else
{
return true;
}
}
if (TI(ngraph::op::Maximum) == TI(*node) || TI(ngraph::op::Minimum) == TI(*node))
{
// TODO: Remove this when MLIR has float point cmp support
if (!node->input(0).get_element_type().is_integral())
{
return false;
}
else
{
return true;
}
}
if (TI(ngraph::op::Greater) == TI(*node) || TI(ngraph::op::Less) == TI(*node))
{
// TODO: Remove this when MLIR has float point cmp support
if (!node->input(0).get_element_type().is_integral())
{
return false;
}
else
{
return true;
}
}
// Relu is supported for integer types only until MLIR adds support for lowering !std.CmpF to LLVM dialect
......@@ -291,6 +334,10 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
{
return false;
}
else
{
return true;
}
}
return true;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment