Commit 3a9de1bb authored by nmostafa's avatar nmostafa

Refactor ng dialect compile

parent c5b976c8
......@@ -30,6 +30,7 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/maximum.hpp"
......@@ -63,6 +64,7 @@
using llvm::SmallVector;
using llvm::StringRef;
using llvm::make_unique;
using llvm::ArrayRef;
using namespace ngraph::runtime::ngmlir;
......@@ -301,131 +303,131 @@ namespace ngraph
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{
return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{
return compiler.create_binary_op<mlir::NGSubOp>(ng_node);
return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{
return compiler.create_binary_op<mlir::NGMulOp>(ng_node);
return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{
return compiler.create_binary_op<mlir::NGDivOp>(ng_node);
return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{
return compiler.create_binary_op<mlir::NGGreaterOp>(ng_node);
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{
return compiler.create_binary_op<mlir::NGLessOp>(ng_node);
return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{
return compiler.create_binary_op<mlir::NGMaxOp>(ng_node);
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{
return compiler.create_binary_op<mlir::NGMinOp>(ng_node);
return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr = compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{
return compiler.create_concat(ng_node);
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("concatenation_axis", compiler.m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis()));
return result;
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{
return compiler.create_unary_op<mlir::NGReluOp>(ng_node);
return nullptr; //compiler.create_gather(ng_node);
}
}
}
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
};
template <typename UnaryOp>
mlir::Value* MLIRCompiler::create_unary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<UnaryOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v)
.getResult();
}
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
auto lhs_v = get_tensor_value(lhs.get()).m_value;
auto rhs_v = get_tensor_value(rhs.get()).m_value;
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), res_type, lhs_v, rhs_v)
.getResult();
}
mlir::Value* MLIRCompiler::create_concat(const ngraph::Node* ng_node)
template <typename Op>
mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
{
std::vector<mlir::Value*> arg_values;
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
std::vector<mlir::Type> res_types;
for (auto& arg : ng_node->get_arguments())
{
auto arg_tensor = arg->get_output_tensor_ptr();
auto arg_v = get_tensor_value(arg_tensor.get()).m_value;
arg_values.push_back(arg_v);
}
auto res_type = get_mlir_type(ng_node->get_output_tensor_ptr().get());
for (auto& output : ng_node->outputs())
{
res_types.push_back(get_mlir_type(output.get_tensor_ptr().get()));
}
return m_builder
->create<mlir::NGConcatOp>(
mlir::UnknownLoc::get(&m_context),
res_type,
arg_values,
m_builder->getI64IntegerAttr(ng_node_concat->get_concatenation_axis()))
.getResult();
->create<Op,
ArrayRef<mlir::Type>,
ArrayRef<mlir::Value *>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context),
res_types,
arg_values, {/* no attrs */}).getResult();
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
};
void MLIRCompiler::create_return()
{
std::vector<mlir::Value*> value_list;
......@@ -436,22 +438,6 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto arg = idx_red->get_argument(0);
size_t red_axis = idx_red->get_reduction_axis();
mlir::Value* arg_val = get_tensor_value(arg->get_output_tensor_ptr().get()).m_value;
mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
return m_builder
->create<RedOp>(
mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
.getResult();
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
......
......@@ -98,13 +98,18 @@ namespace ngraph
void build_ng_dialect();
template <typename OP>
template<typename Op>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
{
throw std::runtime_error("Unimplemented op '" + ng_node->description() +
"' in MLIR Compiler");
}
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename OP>
mlir::Value* create_generic_op(const ngraph::Node* ng_node);
template <typename UnaryOp>
mlir::Value* create_unary_op(const ngraph::Node* ng_node);
......
......@@ -168,6 +168,39 @@ static mlir::LogicalResult verifyCmpOp(T* op)
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGGatherOp* op)
{
Type ty = op->input()->getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->input()->getType();
NGTensorType indicesType = ty.cast<NGTensorType>();
// ensure axis < params rank
if (op->axis().getSExtValue() >= inputType.getRank());
return op->emitOpError("Gather axis is larger than input rank");
ty = indicesType.getElementType();
// ensure indices are I32 or I64
if (!ty.isa<NGIntegerType>())
return op->emitOpError("Indices tensor is not of Integer type");
NGIntegerType indicesEltType = ty.cast<NGIntegerType>();
if (!indicesEltType.isInt32() && !indicesEltType.isInt64())
return op->emitOpError("Indices tensor is not of I32 or I64 type");
mlir::Type r0 = op->res()->getType();
NGTensorType resType = r0.cast<NGTensorType>();
// ensure result is compatible with input
if (!resType.isCompatible(inputType))
return op->emitOpError("Incompatible result shape and/or type");
return mlir::success();
}
namespace mlir
{
#define GET_OP_CLASSES
......
......@@ -186,8 +186,8 @@ def NGDotOp : NG_Binary_Op<"dot">
// class, but I'm not sure how to add concatenation_axis into the args if we
// do that.
def NGConcatOp :
NG_OneResult_Op<"concat", [NoSideEffect]>,
Arguments<(ins Variadic<NG_TensorType>:$args, I64Attr:$concatenation_axis)>
NG_OneResult_Op<"concat", [NoSideEffect]>,
Arguments<(ins Variadic<NG_TensorType>:$args, I64Attr:$concatenation_axis)>
{
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
......@@ -200,7 +200,7 @@ class NG_Axis_Reduction_Op<string mnemonic, list<OpTrait> traits = []> :
{
let summary = "Base class for reduction operations that perform a reduction "
"across the axes of a single tensor.";
let description = "Axes are represented as an array of I64 attributes.";
let description = [{Axes are represented as an array of I64 attributes.}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
......@@ -257,6 +257,24 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
let verifier = [{ return verifyLogicalReductionOp(this); }];
}
// Gather
def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$input, NG_TensorType:$indices, I64Attr:$axis)>
{
let summary = "Gather slices from input along the specified axis according to indices";
let description = [{
Gather slices from axis of input according to indices
input The tensor from which slices are gathered
indices Index tensor: Data type must be `element::i32` or `element::i64`
axis Axis in input to gather
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
}
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
......
......@@ -199,6 +199,7 @@ namespace mlir
}
Shape getShape() const { return m_shape; }
int64_t getRank() const { return m_shape.size(); }
EltType getElementType() const { return m_eltType; }
private:
NGTensorTypeStorage(EltType eltType, Shape shape)
......
......@@ -653,8 +653,13 @@ namespace
return matchSuccess();
}
#undef REWRITER
REWRITER(NGGatherOp)
{
return matchSuccess();
}
#undef REWRITER
/// End of pattern matchers
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
......
......@@ -29,6 +29,7 @@ MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGConcatOp)
MLIR_OP(NGDivOp)
MLIR_OP(NGDotOp)
MLIR_OP(NGGatherOp)
MLIR_OP(NGGreaterOp)
MLIR_OP(NGLessOp)
MLIR_OP(NGMulOp)
......
......@@ -9,6 +9,7 @@ MLIR_OP(ArgMax)
MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Concat)
MLIR_OP(Gather)
MLIR_OP(Greater)
MLIR_OP(Less)
MLIR_OP(Maximum)
......
......@@ -25,6 +25,7 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp"
......@@ -38,6 +39,7 @@ using namespace ngraph::descriptor;
using namespace ngraph::op;
using namespace ngraph::pass;
#define TI(x) std::type_index(typeid(x))
int MLIRSubgraphExtractionPass::MLIRSubgraph::m_curr_graph_id = 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment