Commit 4504a5e2 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by nmostafa

[MLIR] Compile function only if all ops are supported by MLIR backend (#35)

MLIR Compile only supported ops
parent 8bb48c81
......@@ -241,7 +241,7 @@ void MLIRCompiler::build_ng_dialect()
{
const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
for(auto np : sub_graph)
for (auto np : sub_graph)
{
auto it = op_dispatcher.find(TI(*np));
if (it == op_dispatcher.end())
......@@ -273,8 +273,9 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
{TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
{TI(ngraph::op::Dot), &MLIRCompiler::create_op<ngraph::op::Dot>}};
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
#include "ops_supported.inc"
};
template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
......
// List of all ops supported by MLIR backend end-to-end
#ifndef MLIR_OP
#define MLIR_OP
#endif
MLIR_OP(Add)
MLIR_OP(Dot)
// Add new supported ops here
#undef MLIR_OP
\ No newline at end of file
......@@ -15,8 +15,9 @@
//*****************************************************************************
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -32,6 +33,10 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
NodeVector ck_ops;
for (auto op : func->get_ordered_ops())
{
// All ops must be supported by MLIR compiler
if (!is_supported_mlir_op(op))
return false;
if (TI(Parameter) != TI(*op) && TI(Result) != TI(*op))
{
ck_ops.push_back(op);
......@@ -45,7 +50,8 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
}
NodeVector ck_outputs = std::move(get_subgraph_outputs(ck_ops, {} /*exclusions*/));
NGRAPH_ASSERT(ck_outputs.size() == 1) << "Unsupported subgraph with multiple outputs";
if (ck_outputs.size() != 1)
return false;
auto ck = std::make_shared<CompiledKernel>(ck_ops, ck_outputs, ck_args);
......@@ -68,3 +74,30 @@ bool MLIRSubgraphExtractionPass::run_on_function(std::shared_ptr<Function> func)
return true;
}
#define TI(x) std::type_index(typeid(x))
bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
if (TI(Parameter) == TI(*node) || TI(Result) == TI(*node))
return true;
// supported by backend ?
if (m_supported_ops.find(TI(*node)) == m_supported_ops.end())
return false;
// check on invariants expected by MLIR backend
// Dot is 2D only
if (TI(ngraph::op::Dot) == TI(*node))
{
if (node->get_input_shape(0).size() != 2 || node->get_input_shape(1).size() != 2)
return false;
}
return true;
}
const std::set<std::type_index> MLIRSubgraphExtractionPass::m_supported_ops{
#define MLIR_OP(OP) TI(ngraph::op::OP),
#include "contrib/mlir/ops_supported.inc"
};
......@@ -31,6 +31,12 @@ namespace ngraph
public:
MLIRSubgraphExtractionPass() {}
bool run_on_function(std::shared_ptr<Function> func) override;
/// Checks if an ngraph node is supported by MLIR backend
/// Currently this check is only valid for CPU backend.
bool is_supported_mlir_op(std::shared_ptr<Node> node);
private:
static const std::set<std::type_index> m_supported_ops;
};
}
}
......@@ -64,7 +64,7 @@ NGRAPH_TEST(${BACKEND_NAME}, dot_add)
Shape shape_out{2, 3};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto dot = make_shared<op::Dot>(A, B);
auto dot = make_shared<op::Dot>(A, B);
auto C = make_shared<op::Parameter>(element::f32, shape_out);
auto add = make_shared<op::Add>(dot, C);
auto f = make_shared<Function>(add, ParameterVector{A, B, C});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment