Commit 443cbc8a authored by nmostafa's avatar nmostafa

Return Operation* from create_op

parent a651f4d9
...@@ -284,11 +284,20 @@ void MLIRCompiler::build_ng_dialect() ...@@ -284,11 +284,20 @@ void MLIRCompiler::build_ng_dialect()
throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} + throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
np->description() + "' operation"}; np->description() + "' operation"};
} }
mlir::Value* mlir_value = it->second(*this, np.get()); mlir::Operation* op = it->second(*this, np.get());
// builders that have multiple result values will update the value map, and set their ret values to null // This assumes simple 1:1 mapping between output edges and generated MLIR op results
if (mlir_value) // If the mapping is more complex, the create_op helper can return null operation
// and handles populating the value map itself
if (op)
{ {
update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value); for (auto i = 0; i < op->getNumResults(); i++)
{
mlir::Value* result = op->getResult(i);
if (result)
{
update_tensor_value(np->get_output_tensor_ptr(i).get(), result);
}
}
} }
} }
create_return(); create_return();
...@@ -301,97 +310,97 @@ namespace ngraph ...@@ -301,97 +310,97 @@ namespace ngraph
namespace ngmlir namespace ngmlir
{ {
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{ {
return compiler.create_generic_op<mlir::NGAddOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{ {
return compiler.create_generic_op<mlir::NGSubOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{ {
return compiler.create_generic_op<mlir::NGMulOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{ {
return compiler.create_generic_op<mlir::NGDivOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{ {
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{ {
return compiler.create_generic_op<mlir::NGLessOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{ {
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{ {
return compiler.create_generic_op<mlir::NGMinOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{ {
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{ {
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
return compiler.create_generic_op<mlir::NGDotOp>(ng_node).getResult(); return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{ {
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node); auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node); auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
op.setAttr("concatenation_axis", op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr( compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis())); ng_node_concat->get_concatenation_axis()));
return op.getResult(); return op;
} }
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather) mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{ {
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node); auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node); auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
op.setAttr("axis", op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis())); compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
return op.getResult(); return op;
} }
} }
} }
} }
template <typename Op> template <typename Op>
Op MLIRCompiler::create_generic_op(const ngraph::Node* ng_node) mlir::Operation* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
{ {
std::vector<mlir::Value*> arg_values; std::vector<mlir::Value*> arg_values;
std::vector<mlir::Type> res_types; std::vector<mlir::Type> res_types;
...@@ -407,9 +416,9 @@ Op MLIRCompiler::create_generic_op(const ngraph::Node* ng_node) ...@@ -407,9 +416,9 @@ Op MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
res_types.push_back(get_mlir_type(output.get_tensor_ptr().get())); res_types.push_back(get_mlir_type(output.get_tensor_ptr().get()));
} }
return m_builder return (m_builder
->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>( ->create<Op, ArrayRef<mlir::Type>, ArrayRef<mlir::Value*>, ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */}); mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */})).getOperation();
} }
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{ const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
...@@ -428,14 +437,14 @@ void MLIRCompiler::create_return() ...@@ -428,14 +437,14 @@ void MLIRCompiler::create_return()
} }
template <typename RedOp> template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node) mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{ {
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto op = create_generic_op<RedOp>(ng_node); auto op = create_generic_op<RedOp>(ng_node);
mlir::ArrayAttr red_axes_attr = mlir::ArrayAttr red_axes_attr =
m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()}); m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op.setAttr("axes", red_axes_attr); op->setAttr("axes", red_axes_attr);
return op.getResult(); return op;
} }
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
......
...@@ -99,7 +99,7 @@ namespace ngraph ...@@ -99,7 +99,7 @@ namespace ngraph
void build_ng_dialect(); void build_ng_dialect();
template <typename Op> template <typename Op>
static mlir::Value* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node) static mlir::Operation* create_op(MLIRCompiler& compiler, const ngraph::Node* ng_node)
{ {
throw std::runtime_error("Unimplemented op '" + ng_node->description() + throw std::runtime_error("Unimplemented op '" + ng_node->description() +
"' in MLIR Compiler"); "' in MLIR Compiler");
...@@ -108,14 +108,10 @@ namespace ngraph ...@@ -108,14 +108,10 @@ namespace ngraph
// Generic op lowerer to ng dialect. // Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic. // Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename Op> template <typename Op>
Op create_generic_op(const ngraph::Node* ng_node); mlir::Operation* create_generic_op(const ngraph::Node* ng_node);
// TODO(amprocte): Can we have a create_variadic_op that is able to handle the
// attributes?
mlir::Value* create_concat(const ngraph::Node* ng_node);
template <typename RedOp> template <typename RedOp>
mlir::Value* create_index_reduction(const ngraph::Node* ng_node); mlir::Operation* create_index_reduction(const ngraph::Node* ng_node);
void create_return(); void create_return();
...@@ -149,7 +145,7 @@ namespace ngraph ...@@ -149,7 +145,7 @@ namespace ngraph
using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>; using TensorToInfo = std::pair<descriptor::Tensor*, TensorInfo>;
using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>; using TensorToInfoMap = std::unordered_map<descriptor::Tensor*, TensorInfo>;
using MLIRCompOpFunction = using MLIRCompOpFunction =
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>; std::function<mlir::Operation*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>; using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
// Maps tensor to the value it represents in the IR // Maps tensor to the value it represents in the IR
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment