Commit 073db8fb authored by Pruthvi, committed by Scott Cyphers

- MLIR Binary ElementWise Op (#3223)

* - Templatize binary elementwise op compilation (see the sketch below)
- Added lowering support for Add, Sub, Multiply, Divide

* - Added support for the Greater and Less ops

* - Added support for Minimum and Maximum

* Use edsc::intrinsics::select instead of the ternary operator

* Addressed PR comments

* - Return after the conditional check
parent 5f0391d3
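
For orientation, a minimal stand-alone sketch (plain C++17, hypothetical names, not the nGraph/MLIR API) of the pattern this commit introduces: each per-op rewriter becomes a thin wrapper that delegates to a single templated helper, and only the innermost scalar operation differs per op (the real helper below picks it with isa<> checks on the op).

#include <iostream>
#include <type_traits>

// Hypothetical stand-ins for the dialect op classes.
struct AddOp {};
struct SubOp {};
struct MulOp {};

// One templated helper carries the shared lowering boilerplate; only the
// innermost statement differs per op.
template <typename OP>
int lower_binary_elementwise(int lhs, int rhs)
{
    if constexpr (std::is_same_v<OP, AddOp>) { return lhs + rhs; }
    else if constexpr (std::is_same_v<OP, SubOp>) { return lhs - rhs; }
    else { return lhs * rhs; }
}

// Thin per-op rewriters just delegate, mirroring
// REWRITER(NGAddOp) { lower_binary_elementwise<mlir::NGAddOp>(...); }.
int rewrite_add(int lhs, int rhs) { return lower_binary_elementwise<AddOp>(lhs, rhs); }
int rewrite_sub(int lhs, int rhs) { return lower_binary_elementwise<SubOp>(lhs, rhs); }

int main()
{
    std::cout << rewrite_add(4, 2) << " " << rewrite_sub(4, 2) << "\n"; // prints "6 2"
    return 0;
}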
@@ -26,9 +26,16 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
@@ -296,6 +303,48 @@ namespace ngraph
    return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{
    return compiler.create_binary_op<mlir::NGSubOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{
    return compiler.create_binary_op<mlir::NGMulOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{
    return compiler.create_binary_op<mlir::NGDivOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{
    return compiler.create_binary_op<mlir::NGGreaterOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{
    return compiler.create_binary_op<mlir::NGLessOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{
    return compiler.create_binary_op<mlir::NGMaxOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{
    return compiler.create_binary_op<mlir::NGMinOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
...
@@ -83,6 +83,12 @@ namespace
                         PatternRewriter& rewriter,
                         DialectLoweringPass& m_pass);

template <typename OP>
void lower_binary_elementwise(Operation* op,
                              ArrayRef<Value*> operands,
                              PatternRewriter& rewriter,
                              DialectLoweringPass& m_pass);

/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
{
@@ -363,42 +369,50 @@ namespace
// ADD
REWRITER(NGAddOp)
{
    lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGSubOp)
{
    lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGMulOp)
{
    lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGDivOp)
{
    lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGGreaterOp)
{
    lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGLessOp)
{
    lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGMaxOp)
{
    lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}

REWRITER(NGMinOp)
{
    lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, m_pass);
    return matchSuccess();
}
@@ -541,6 +555,81 @@ namespace
#undef REWRITER
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result->getType().isa<MemRefType>());
// get new operands
Value* lhs = operands[0];
Value* rhs = operands[1];
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vLHS.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
LoopNestBuilder(pivs, lbs, ubs, steps)(
// single stmt body
[&] {
if (isa<NGAddOp>(op))
{
iRes(ivs) = iLHS(ivs) + iRHS(ivs);
}
else if (isa<NGSubOp>(op))
{
iRes(ivs) = iLHS(ivs) - iRHS(ivs);
}
else if (isa<NGMulOp>(op))
{
iRes(ivs) = iLHS(ivs) * iRHS(ivs);
}
else if (isa<NGDivOp>(op))
{
iRes(ivs) = iLHS(ivs) / iRHS(ivs);
}
else if (isa<NGGreaterOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs));
}
else if (isa<NGLessOp>(op))
{
iRes(ivs) = ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs));
}
else if (isa<NGMaxOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else if (isa<NGMinOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else
{
NGRAPH_CHECK(false, "Unsupported op");
}
});
rewriter.replaceOp(op, {result});
}
template <typename RedOp>
void lowerIndexReduction(Operation* op,
                         ArrayRef<Value*> operands,
...
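
A note on the NGMaxOp/NGMinOp bodies in the helper above: the comparison of ValueHandle operands yields an IR value rather than a host-side bool, which is the likely motivation for the commit-message note about using edsc::intrinsics::select(cond, a, b) instead of a C++ ternary, so that the choice is materialized in the generated IR. A minimal stand-alone analogy follows (plain C++; this select is a hypothetical helper, not the MLIR EDSC intrinsic).

#include <cassert>

// Hypothetical helper standing in for edsc::intrinsics::select(cond, a, b):
// the choice is expressed as a value computation, mirroring how the lowering
// emits a select rather than host-side control flow.
int select(bool cond, int a, int b)
{
    return cond ? a : b;
}

int main()
{
    int lhs = 7, rhs = 3;
    // Maximum and Minimum expressed through select, mirroring the
    // NGMaxOp / NGMinOp loop bodies in lower_binary_elementwise above.
    assert(select(lhs > rhs, lhs, rhs) == 7); // max
    assert(select(lhs < rhs, lhs, rhs) == 3); // min
    return 0;
}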
@@ -26,8 +26,15 @@
MLIR_OP(NGAddOp)
MLIR_OP(NGArgMaxRedOp)
MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGDivOp)
MLIR_OP(NGDotOp)
MLIR_OP(NGGreaterOp)
MLIR_OP(NGLessOp)
MLIR_OP(NGMulOp)
MLIR_OP(NGMaxOp)
MLIR_OP(NGMinOp)
MLIR_OP(NGReluOp)
MLIR_OP(NGSubOp)
MLIR_LAST_OP(NGReturnOp)

#undef MLIR_OP
...
@@ -6,8 +6,15 @@
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(ArgMax)
MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Greater)
MLIR_OP(Less)
MLIR_OP(Maximum)
MLIR_OP(Minimum)
MLIR_OP(Multiply)
MLIR_OP(Subtract)
MLIR_OP(Relu)
// Add new supported ops here
#undef MLIR_OP
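
The two op lists above appear to follow the usual X-macro convention: the list file only invokes MLIR_OP(...), and each includer presumably defines MLIR_OP to the expansion it needs before pulling the list in. A minimal sketch of that technique (plain C++; OPS_LIST and the enum/name tables are hypothetical, not the actual nGraph headers):

#include <iostream>

// OPS_LIST stands in for the .inc file so the sketch fits in one
// translation unit; it only invokes MLIR_OP, which is defined by the user.
#define OPS_LIST      \
    MLIR_OP(Add)      \
    MLIR_OP(Subtract) \
    MLIR_OP(Multiply)

// First expansion of the list: build an enum of supported ops.
#define MLIR_OP(x) x,
enum class SupportedOp
{
    OPS_LIST
    Count
};
#undef MLIR_OP

// Second expansion of the same list: build printable names.
#define MLIR_OP(x) #x,
static const char* op_names[] = {OPS_LIST};
#undef MLIR_OP

int main()
{
    for (const char* name : op_names)
    {
        std::cout << name << "\n"; // Add, Subtract, Multiply
    }
    return 0;
}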
@@ -21,10 +21,17 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/subtract.hpp"

using namespace ngraph::descriptor;
using namespace ngraph::op;
@@ -275,13 +282,49 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
        {
            return false;
        }
        else
        {
            return true;
        }
    }

    if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
    {
        // TODO: Remove this when MLIR has float point cmp support
        if (!node->input(0).get_element_type().is_integral())
        {
            return false;
        }
        else
        {
            return true;
        }
    }

    if (TI(ngraph::op::Maximum) == TI(*node) || TI(ngraph::op::Minimum) == TI(*node))
    {
        // TODO: Remove this when MLIR has float point cmp support
        if (!node->input(0).get_element_type().is_integral())
        {
            return false;
        }
        else
        {
            return true;
        }
    }

    if (TI(ngraph::op::Greater) == TI(*node) || TI(ngraph::op::Less) == TI(*node))
    {
        // TODO: Remove this when MLIR has float point cmp support
        if (!node->input(0).get_element_type().is_integral())
        {
            return false;
        }
        else
        {
            return true;
        }
    }

    // Relu is supported for integer types only until MLIR adds support for lowering !std.CmpF to LLVM dialect
@@ -291,6 +334,10 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
        {
            return false;
        }
        else
        {
            return true;
        }
    }

    return true;
}
...
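
For readers unfamiliar with the TI(...) idiom in is_supported_mlir_op above: it compares the node's dynamic type against a concrete op type, and the new branches then gate MLIR support on the element type being integral. A minimal stand-alone sketch follows (plain C++; Node, Maximum, and Minimum are hypothetical stand-ins, and TI is assumed here to expand to std::type_index(typeid(x)), so check nGraph's own definition):

#include <cassert>
#include <typeindex>
#include <typeinfo>

// Assumed definition of the TI macro used in the pass (hedged; not copied
// from this diff).
#define TI(x) std::type_index(typeid(x))

struct Node { virtual ~Node() = default; }; // hypothetical stand-in for ngraph::Node
struct Maximum : Node {};
struct Minimum : Node {};

// Mirrors: TI(ngraph::op::Maximum) == TI(*node) || TI(ngraph::op::Minimum) == TI(*node)
bool is_max_or_min(const Node& node)
{
    return TI(Maximum) == TI(node) || TI(Minimum) == TI(node);
}

int main()
{
    Maximum max_node;
    Minimum min_node;
    Node other;
    assert(is_max_or_min(max_node));
    assert(is_max_or_min(min_node));
    assert(!is_max_or_min(other));
    return 0;
}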