Commit a651f4d9 authored by nmostafa's avatar nmostafa

PR fixes

parent 158a82eb
......@@ -303,109 +303,95 @@ namespace ngraph
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{
return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
return compiler.create_generic_op<mlir::NGAddOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{
return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
return compiler.create_generic_op<mlir::NGSubOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{
return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
return compiler.create_generic_op<mlir::NGMulOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{
return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
return compiler.create_generic_op<mlir::NGDivOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{
return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
return compiler.create_generic_op<mlir::NGLessOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{
return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
return compiler.create_generic_op<mlir::NGMinOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMaxRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGArgMinRedOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
mlir::ArrayAttr red_axes_attr =
compiler.m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
return result;
return compiler.create_generic_op<mlir::NGDotOp>(ng_node).getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
return result;
auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
op.setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
return op.getResult();
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
mlir::Value* result = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
mlir::Operation* op = result->getDefiningOp();
op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return result;
auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
op.setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return op.getResult();
}
}
}
}
template <typename Op>
mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
Op MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
{
std::vector<mlir::Value*> arg_values;
std::vector<mlir::Type> res_types;
......@@ -423,8 +409,7 @@ mlir::Value* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
return m_builder
->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */})
.getResult();
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */});
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
......@@ -442,6 +427,17 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto op = create_generic_op<RedOp>(ng_node);
mlir::ArrayAttr red_axes_attr =
m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op.setAttr("axes", red_axes_attr);
return op.getResult();
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
......
......@@ -107,14 +107,8 @@ namespace ngraph
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename OP>
mlir::Value* create_generic_op(const ngraph::Node* ng_node);
template <typename UnaryOp>
mlir::Value* create_unary_op(const ngraph::Node* ng_node);
template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
template <typename Op>
Op create_generic_op(const ngraph::Node* ng_node);
// TODO(amprocte): Can we have a create_variadic_op that is able to handle the
// attributes?
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment