Commit 4504a5e2 authored by Nagy Mostafa, committed by nmostafa

[MLIR] Compile function only if all ops are supported by MLIR backend (#35)

MLIR: compile only ops supported by the backend
parent 8bb48c81
@@ -241,7 +241,7 @@ void MLIRCompiler::build_ng_dialect()
 {
     const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
-    for(auto np : sub_graph)
+    for (auto np : sub_graph)
     {
         auto it = op_dispatcher.find(TI(*np));
         if (it == op_dispatcher.end())
@@ -273,8 +273,9 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
 }

 const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
-    {TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
-    {TI(ngraph::op::Dot), &MLIRCompiler::create_op<ngraph::op::Dot>}};
+#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
+#include "ops_supported.inc"
+};

 template <typename BinOp>
 mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
New file: ops_supported.inc

// List of all ops supported by MLIR backend end-to-end
#ifndef MLIR_OP
#define MLIR_OP(OP)
#endif

MLIR_OP(Add)
MLIR_OP(Dot)

// Add new supported ops here

#undef MLIR_OP
\ No newline at end of file
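This .inc file is an X-macro list: every file that includes it first defines MLIR_OP to expand each entry into whatever construct it needs (a dispatcher entry in MLIRCompiler::op_dispatcher above, a std::type_index in the extraction pass below), so the set of supported ops is maintained in exactly one place. A minimal, self-contained sketch of the pattern, using a hypothetical COLOR_LIST macro in place of a separate .inc file:

#include <iostream>
#include <string>
#include <vector>

// Stand-in for ops_supported.inc: each entry expands through whatever
// X(...) definition is in scope at the point of use.
#define COLOR_LIST \
    X(Red)         \
    X(Green)       \
    X(Blue)

// Expansion 1: an enum with one enumerator per entry.
enum class Color
{
#define X(C) C,
    COLOR_LIST
#undef X
};

// Expansion 2: matching printable names, kept in sync automatically.
static const std::vector<std::string> color_names{
#define X(C) #C,
    COLOR_LIST
#undef X
};

int main()
{
    for (const auto& name : color_names)
        std::cout << name << '\n'; // prints Red, Green, Blue
}

Adding one entry to the list updates every expansion site at once, which is exactly why this commit routes both the compiler's dispatch table and the extraction pass's supported-op set through ops_supported.inc.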
@@ -15,8 +15,9 @@
 //*****************************************************************************
 #include "mlir_subgraph_extraction.hpp"
 #include "ngraph/graph_util.hpp"
+#include "ngraph/op/add.hpp"
+#include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
 #include "ngraph/op/get_output_element.hpp"
@@ -32,6 +33,10 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
     NodeVector ck_ops;
     for (auto op : func->get_ordered_ops())
     {
+        // All ops must be supported by MLIR compiler
+        if (!is_supported_mlir_op(op))
+            return false;
+
         if (TI(Parameter) != TI(*op) && TI(Result) != TI(*op))
         {
             ck_ops.push_back(op);
@@ -45,7 +50,8 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
     }

     NodeVector ck_outputs = std::move(get_subgraph_outputs(ck_ops, {} /*exclusions*/));
-    NGRAPH_ASSERT(ck_outputs.size() == 1) << "Unsupported subgraph with multiple outputs";
+    if (ck_outputs.size() != 1)
+        return false;

     auto ck = std::make_shared<CompiledKernel>(ck_ops, ck_outputs, ck_args);
@@ -68,3 +74,30 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
     return true;
 }
+
+#define TI(x) std::type_index(typeid(x))
+
+bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
+{
+    if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
+        return true;
+
+    // supported by backend?
+    if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
+        return false;
+
+    // check on invariants expected by MLIR backend
+
+    // Dot is 2D only
+    if (TI(ngraph::op::Dot) == TI(*node))
+    {
+        if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
+            return false;
+    }
+
+    return true;
+}
+
+const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
+#define MLIR_OP(OP) TI(ngraph::op::OP),
+#include "contrib/mlir/ops_supported.inc"
+};
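Both the op_dispatcher map and m_supported_ops key on std::type_index, and the TI macro applied to a dereferenced node (TI(*node)) uses the node's dynamic type, which is what makes the set lookup match the concrete op class. A standalone sketch of the idiom, with an illustrative node hierarchy rather than nGraph's real classes:

#include <iostream>
#include <memory>
#include <set>
#include <typeindex>

// Illustrative polymorphic hierarchy; nGraph's real Node/op classes are richer.
struct Node { virtual ~Node() = default; };
struct Add : Node {};
struct Dot : Node {};
struct Sin : Node {};

#define TI(x) std::type_index(typeid(x))

// Same idiom as m_supported_ops: one type_index per supported op.
static const std::set<std::type_index> supported{TI(Add), TI(Dot)};

bool is_supported(const std::shared_ptr<Node>& n)
{
    // typeid on the dereferenced pointer yields the dynamic type, because
    // Node is polymorphic; that is why the lookup matches the concrete op.
    return supported.find(TI(*n)) != supported.end();
}

int main()
{
    std::shared_ptr<Node> a = std::make_shared<Add>();
    std::shared_ptr<Node> s = std::make_shared<Sin>();
    std::cout << is_supported(a) << ' ' << is_supported(s) << '\n'; // 1 0
}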
@@ -31,6 +31,12 @@ namespace ngraph
     public:
         MLIRSubgraphExtractionPass() {}
         bool run_on_function(std::shared_ptr<Function> func) override;
+
+        /// Checks if an ngraph node is supported by the MLIR backend.
+        /// Currently this check is only valid for the CPU backend.
+        bool is_supported_mlir_op(std::shared_ptr<Node> node);
+
+    private:
+        static const std::set<std::type_index> m_supported_ops;
     };
 }
}
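For context, a pass like this is normally run through nGraph's pass manager before backend compilation. A hedged sketch of such a driver, assuming the stock ngraph::pass::Manager API and that MLIRSubgraphExtractionPass lives in namespace ngraph::pass as the header above suggests:

#include <memory>

#include "mlir_subgraph_extraction.hpp"
#include "ngraph/pass/manager.hpp"

// Sketch only: how a backend might wire the extraction pass into its
// pipeline; the exact integration point is backend-specific.
void extract_mlir_subgraphs(std::shared_ptr<ngraph::Function> func)
{
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::MLIRSubgraphExtractionPass>();
    // Rewrites fully supported subgraphs into CompiledKernel nodes;
    // leaves the function untouched if any op fails is_supported_mlir_op.
    pass_manager.run_passes(func);
}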
@@ -64,7 +64,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dot_add)
     Shape shape_out{2, 3};
     auto A = make_shared<op::Parameter>(element::f32, shape_in1);
     auto B = make_shared<op::Parameter>(element::f32, shape_in2);
     auto dot = make_shared<op::Dot>(A, B);
     auto C = make_shared<op::Parameter>(element::f32, shape_out);
     auto add = make_shared<op::Add>(dot, C);
     auto f = make_shared<Function>(add, ParameterVector{A, B, C});