Commit 75a5cb00 authored by nmostafa's avatar nmostafa

Review fixes

parent f4954d57
......@@ -87,24 +87,15 @@ void MLIRCompiler::init_mlir()
}
}
void MLIRCompiler::set_args(std::vector<void*>* external_tensors)
{
NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors->size(),
"Number of arguments and outputs doesn't match number of tensors");
m_external_tensors = external_tensors;
}
void MLIRCompiler::compile()
{
build_ng_dialect_module();
lower_ng_dialect();
}
void MLIRCompiler::run()
void MLIRCompiler::run(std::vector<void*>& external_tensors)
{
bind_arguments();
bind_arguments(external_tensors);
execute();
cleanup();
}
......@@ -487,13 +478,20 @@ mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_nod
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
mlir::Function* func = m_module->getNamedFunction("main");
NGRAPH_CHECK(func && !func->getBlocks().empty(), "Function not found");
// Set external arguments
NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
"Number of arguments and outputs doesn't match number of tensors");
m_external_tensors = &external_tensors;
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemRefArguments', which creates a
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
......
......@@ -65,12 +65,10 @@ namespace ngraph
{
}
/// Set runtime tensor arguments for the sub-graph
void set_args(std::vector<void*>* external_tensors);
/// Compiles a subgraph with MLIR
void compile();
/// Executes a pre-compiled subgraph
void run();
void run(std::vector<void*>& external_tensors);
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
......@@ -90,7 +88,7 @@ namespace ngraph
private:
void build_ng_dialect_module();
void lower_ng_dialect();
void bind_arguments();
void bind_arguments(std::vector<void*>& external_tensors);
void execute();
void cleanup();
......
......@@ -83,8 +83,7 @@ namespace ngraph
{
mlir_compiler.compile();
}
mlir_compiler.set_args(&ptr_args);
mlir_compiler.run();
mlir_compiler.run(ptr_args);
};
functors.emplace_back(functor);
......
......@@ -72,6 +72,9 @@ namespace ngraph
std::set<size_t> breakpoints;
size_t pc;
#ifdef NGRAPH_MLIR_ENABLE
/// Maps CompiledKernel nodes to their MLIR compiler
/// The MLIR compiler caches the compiled code on the first invocation,
/// and may in the future support re-compilation
std::unordered_map<ngraph::op::CompiledKernel*,
ngraph::runtime::ngmlir::MLIRCompiler>
mlir_compilers;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment