Commit a651f4d9 authored by nmostafa's avatar nmostafa

PR fixes

parent 158a82eb
...@@ -303,109 +303,95 @@ namespace ngraph ...@@ -303,109 +303,95 @@ namespace ngraph
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{ {
return compiler.create_generic_op<mlir::NGAddOp>(ng_node); return compiler.create_generic_op<mlir::NGAddOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{ {
return compiler.create_generic_op<mlir::NGSubOp>(ng_node); return compiler.create_generic_op<mlir::NGSubOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{ {
return compiler.create_generic_op<mlir::NGMulOp>(ng_node); return compiler.create_generic_op<mlir::NGMulOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{ {
return compiler.create_generic_op<mlir::NGDivOp>(ng_node); return compiler.create_generic_op<mlir::NGDivOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{ {
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node); return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{ {
return compiler.create_generic_op<mlir::NGLessOp>(ng_node); return compiler.create_generic_op<mlir::NGLessOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{ {
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node); return compiler.create_generic_op<mlir::NGMaxOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{ {
return compiler.create_generic_op<mlir::NGMinOp>(ng_node); return compiler.create_generic_op<mlir::NGMinOp>(ng_node).getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{ {
return compiler.create_generic_op<mlir::NGDotOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{ {
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); return compiler.create_generic_op<mlir::NGDotOp>(ng_node).getResult();
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{ {
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node); auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node); auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); op.setAttr("concatenation_axis",
op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr( compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis())); ng_node_concat->get_concatenation_axis()));
return result; return op.getResult();
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{ {
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node); auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node); auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp(); op.setAttr("axis",
op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis())); compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result; return op.getResult();
} }
} }
} }
} }
template <typename Op> template <typename Op>
mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node) Op MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
{ {
std::vector<mlir::Value*> arg_values; std::vector<mlir::Value*> arg_values;
std::vector<mlir::Type> res_types; std::vector<mlir::Type> res_types;
...@@ -423,8 +409,7 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node) ...@@ -423,8 +409,7 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
return m_builder return m_builder
->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>( ->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */}) mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */});
.getResult();
} }
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
...@@ -442,6 +427,17 @@ void MLIRCompiler::create_return() ...@@ -442,6 +427,17 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list); m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
} }
template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto op = create_generic_op<RedOp>(ng_node);
mlir::ArrayAttr red_axes_attr =
m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op.setAttr("axes", red_axes_attr);
return op.getResult();
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function. // helpers to be used inside the function.
void MLIRCompiler::bind_arguments() void MLIRCompiler::bind_arguments()
......
...@@ -107,14 +107,8 @@ namespace ngraph ...@@ -107,14 +107,8 @@ namespace ngraph
// Generic op lowerer to ng dialect. // Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic. // Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename OP> template <typename Op>
mlir::Value* create_generic_op(const ngraph::Node* ng_node); Op create_generic_op(const ngraph::Node* ng_node);
template <typename UnaryOp>
mlir::Value* create_unary_op(const ngraph::Node* ng_node);
template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
// TODO(amprocte): Can we have a create_variadic_op that is able to handle the // TODO(amprocte): Can we have a create_variadic_op that is able to handle the
// attributes? // attributes?
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment