Commit c6e9747b authored by nmostafa's avatar nmostafa

Move CompiledKernel back to ngraph core. Add a CompiledKernel->MLIRCompiler map in RuntimeContext

parent bd15a4d9
...@@ -23,8 +23,6 @@ set(SRC ...@@ -23,8 +23,6 @@ set(SRC
memory_manager.cpp memory_manager.cpp
pass/mlir_subgraph_extraction.cpp pass/mlir_subgraph_extraction.cpp
pass/mlir_subgraph_extraction.hpp pass/mlir_subgraph_extraction.hpp
compiled_kernel.cpp
compiled_kernel.hpp
) )
if (NGRAPH_MLIR_ENABLE) if (NGRAPH_MLIR_ENABLE)
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/less.hpp" #include "ngraph/op/less.hpp"
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "contrib/mlir/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
......
...@@ -174,6 +174,8 @@ set (SRC ...@@ -174,6 +174,8 @@ set (SRC
op/experimental/quantized_dot.hpp op/experimental/quantized_dot.hpp
op/experimental/quantized_dot_bias.cpp op/experimental/quantized_dot_bias.cpp
op/experimental/quantized_dot_bias.hpp op/experimental/quantized_dot_bias.hpp
op/experimental/compiled_kernel.cpp
op/experimental/compiled_kernel.hpp
op/experimental/transpose.cpp op/experimental/transpose.cpp
op/experimental/transpose.hpp op/experimental/transpose.hpp
op/experimental/layers/ctc_greedy_decoder.cpp op/experimental/layers/ctc_greedy_decoder.cpp
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include "compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
...@@ -67,8 +67,6 @@ ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list, ...@@ -67,8 +67,6 @@ ngraph::op::CompiledKernel::CompiledKernel(const NodeVector& node_list,
: Op("CompiledKernel", check_single_output_args({args})) : Op("CompiledKernel", check_single_output_args({args}))
, m_node_list(node_list) , m_node_list(node_list)
, m_output_nodes(outputs) , m_output_nodes(outputs)
, m_mlir_compiler(this)
, m_is_compiled(false)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
set_output_size(m_output_nodes.size()); set_output_size(m_output_nodes.size());
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include "ngraph/op/op.hpp" #include "ngraph/op/op.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "contrib/mlir/compiler.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -39,34 +38,11 @@ namespace ngraph ...@@ -39,34 +38,11 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
const NodeVector& get_node_list() const { return m_node_list; } const NodeVector& get_node_list() const { return m_node_list; }
const NodeVector& get_kernel_outputs() const { return m_output_nodes; } const NodeVector& get_kernel_outputs() const { return m_output_nodes; }
/// Compiles the sub-graph associated with this CompiledKernel
void compile()
{
if (m_is_compiled)
{
return;
}
m_mlir_compiler.compile();
m_is_compiled = true;
}
/// Runs the sub-graph
void run(std::vector<void*>& ptr_args)
{
NGRAPH_CHECK(m_is_compiled, "CompiledKernel node not compiled yet");
m_mlir_compiler.set_args(&ptr_args);
m_mlir_compiler.run();
}
bool is_compiled() const
{
return m_is_compiled;
}
private: private:
NodeVector m_node_list; NodeVector m_node_list;
NodeVector m_output_nodes; NodeVector m_output_nodes;
ngraph::runtime::ngmlir::MLIRCompiler m_mlir_compiler;
bool m_is_compiled;
}; };
} }
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "ngraph/runtime/cpu/cpu_builder.hpp" #include "ngraph/runtime/cpu/cpu_builder.hpp"
#include "contrib/mlir/compiler.hpp" #include "contrib/mlir/compiler.hpp"
#include "contrib/mlir/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/runtime/cpu/cpu_runtime_context.hpp" #include "ngraph/runtime/cpu/cpu_runtime_context.hpp"
using namespace ngraph; using namespace ngraph;
...@@ -67,8 +67,23 @@ namespace ngraph ...@@ -67,8 +67,23 @@ namespace ngraph
} }
// Compile nodes within the CompiledKernel op. // Compile nodes within the CompiledKernel op.
CompiledKernel* compiled_kernel = static_cast<CompiledKernel*>(const_cast<Node*>(node)); CompiledKernel* compiled_kernel = static_cast<CompiledKernel*>(const_cast<Node*>(node));
compiled_kernel->compile(); bool is_module_ready = true;
compiled_kernel->run(ptr_args); auto it = ctx->mlir_compilers.find(compiled_kernel);
if (it == ctx->mlir_compilers.end())
{
// create a new compiler for the CK
ctx->mlir_compilers.emplace(compiled_kernel, compiled_kernel);
is_module_ready = false;
}
MLIRCompiler& mlir_compiler = ctx->mlir_compilers.find(compiled_kernel)->second;
if (!is_module_ready)
{
mlir_compiler.compile();
}
mlir_compiler.set_args(&ptr_args);
mlir_compiler.run();
}; };
functors.emplace_back(functor); functors.emplace_back(functor);
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "contrib/mlir/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
......
...@@ -73,7 +73,7 @@ ...@@ -73,7 +73,7 @@
#include "ngraph/op/erf.hpp" #include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "contrib/mlir/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/quantized_avg_pool.hpp" #include "ngraph/op/experimental/quantized_avg_pool.hpp"
#include "ngraph/op/experimental/quantized_concat.hpp" #include "ngraph/op/experimental/quantized_concat.hpp"
......
...@@ -25,6 +25,8 @@ ...@@ -25,6 +25,8 @@
#include <tbb/flow_graph.h> #include <tbb/flow_graph.h>
#include <tbb/global_control.h> #include <tbb/global_control.h>
#include <tbb/task_scheduler_init.h> #include <tbb/task_scheduler_init.h>
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "contrib/mlir/compiler.hpp"
namespace mkldnn namespace mkldnn
{ {
...@@ -66,6 +68,9 @@ namespace ngraph ...@@ -66,6 +68,9 @@ namespace ngraph
State* const* states; State* const* states;
std::set<size_t> breakpoints; std::set<size_t> breakpoints;
size_t pc; size_t pc;
#ifdef NGRAPH_MLIR_ENABLE
std::unordered_map<ngraph::op::CompiledKernel*, ngraph::runtime::ngmlir::MLIRCompiler> mlir_compilers;
#endif
}; };
} }
......
...@@ -52,6 +52,7 @@ ...@@ -52,6 +52,7 @@
#include "ngraph/op/erf.hpp" #include "ngraph/op/erf.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp" #include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp" #include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp" #include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp" #include "ngraph/op/experimental/dyn_replace_slice.hpp"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment