Commit 7c68d855 authored by Diego Caballero's avatar Diego Caballero Committed by nmostafa

[MLIR] Add documentation and clean up compiler.* files (#40)

parent 5b9000ac
...@@ -85,15 +85,16 @@ void MLIRCompiler::init_mlir() ...@@ -85,15 +85,16 @@ void MLIRCompiler::init_mlir()
void MLIRCompiler::compile_and_run() void MLIRCompiler::compile_and_run()
{ {
build_module(); // MLIR gen build_ng_dialect_module();
lower_dialect(); lower_ng_dialect();
optimize(); optimize();
bind_arguments(); bind_arguments();
execute(); execute();
cleanup(); cleanup();
} }
void MLIRCompiler::build_module() // Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
{ {
// initialize an empty module // initialize an empty module
m_module = make_unique<mlir::Module>(&m_context); m_module = make_unique<mlir::Module>(&m_context);
...@@ -146,6 +147,8 @@ void MLIRCompiler::build_module() ...@@ -146,6 +147,8 @@ void MLIRCompiler::build_module()
} }
} }
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor) mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{ {
SmallVector<int64_t, 4> shape; SmallVector<int64_t, 4> shape;
...@@ -157,6 +160,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor) ...@@ -157,6 +160,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape); return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
} }
// Converts an nGraph element type into an MLIR type.
mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{ {
switch (type.get_type_enum()) switch (type.get_type_enum())
...@@ -164,31 +168,20 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type) ...@@ -164,31 +168,20 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
case ngraph::element::Type_t::undefined: case ngraph::element::Type_t::undefined:
case ngraph::element::Type_t::dynamic: case ngraph::element::Type_t::dynamic:
default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break; default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context); case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context); case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context); case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context);
case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context); case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context);
case ngraph::element::Type_t::u8: case ngraph::element::Type_t::u8:
case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context); case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context);
case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context); case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context);
case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getInt16(&m_context); case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getInt16(&m_context);
case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context); case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context);
case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context); case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context);
case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context); case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context); case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
} }
NGRAPH_FAIL(); // Unreachable NGRAPH_FAIL() << "Unreachable";
return mlir::Type(); return mlir::Type();
} }
...@@ -209,7 +202,8 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens ...@@ -209,7 +202,8 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
return it->second; return it->second;
} }
void MLIRCompiler::lower_dialect() // Lowers nGraph dialect to affine dialect.
void MLIRCompiler::lower_ng_dialect()
{ {
mlir::PassManager pm; mlir::PassManager pm;
pm.addPass(mlir::createDialectLoweringPass(this)); pm.addPass(mlir::createDialectLoweringPass(this));
...@@ -227,14 +221,16 @@ void MLIRCompiler::lower_dialect() ...@@ -227,14 +221,16 @@ void MLIRCompiler::lower_dialect()
} }
} }
// Receives affine dialect as input and applies affine and standard dialect based optimizations.
// Lowering from affine dialect to standard dialect happens along the way. Output consists of
// standard dialect only ops.
void MLIRCompiler::optimize() void MLIRCompiler::optimize()
{ {
mlir::PassManager pm; mlir::PassManager pm;
// Lower affine ops // Lower affine ops
pm.addPass(mlir::createLowerAffinePass()); pm.addPass(mlir::createLowerAffinePass());
auto rr = pm.run(m_module.get()); auto rr = pm.run(m_module.get());
(void)rr; NGRAPH_ASSERT(succeeded(rr)) << "Affine loop lowering failed";
assert(succeeded(rr) && "affine loop lowering failed");
} }
// MLIR builders // MLIR builders
...@@ -271,7 +267,6 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) ...@@ -271,7 +267,6 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
template <> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{ {
NGRAPH_ASSERT(ng_node->get_arguments().size() == 2) << "Expected two operands in Dot operation";
return compiler.create_binary_op<mlir::NGDotOp>(ng_node); return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
} }
...@@ -302,6 +297,8 @@ void MLIRCompiler::create_return() ...@@ -302,6 +297,8 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list); m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
} }
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments() void MLIRCompiler::bind_arguments()
{ {
NGRAPH_ASSERT(m_module && "MLIR module is not ready."); NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
...@@ -338,6 +335,7 @@ void MLIRCompiler::bind_arguments() ...@@ -338,6 +335,7 @@ void MLIRCompiler::bind_arguments()
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg)); m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
} }
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCompiler::execute() void MLIRCompiler::execute()
{ {
NGRAPH_ASSERT(m_module && "MLIR module is not ready."); NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
......
...@@ -47,9 +47,13 @@ namespace ngraph ...@@ -47,9 +47,13 @@ namespace ngraph
{ {
namespace ngmlir namespace ngmlir
{ {
/// This class is the entry point to MLIR from nGraph. It drives the conversion of
/// nGraph sub-graphs, represented with CompiledKernel nodes, to MLIR nGraph dialect
/// and its lowering, optimization and execution using LLVM-based MLIR execution engine.
class MLIRCompiler class MLIRCompiler
{ {
public: public:
/// Initializes MLIR environment. It must be called only once per execution.
static void init_mlir(); static void init_mlir();
public: public:
...@@ -59,12 +63,13 @@ namespace ngraph ...@@ -59,12 +63,13 @@ namespace ngraph
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel, MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors); const std::vector<void*>& external_tensors);
/// Compiles and runs a subgraph in MLIR /// Compiles and runs a subgraph in MLIR.
void compile_and_run(); void compile_and_run();
/// Returns the memory manager used by this sub-graph compiler /// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; } MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
/// Returns memory manager pointer argument ID in call interface
/// Returns memory manager pointer argument ID in call interface.
unsigned get_mem_mgr_arg_id(mlir::Function* func) unsigned get_mem_mgr_arg_id(mlir::Function* func)
{ {
return func->getNumArguments() - 1; return func->getNumArguments() - 1;
...@@ -73,13 +78,13 @@ namespace ngraph ...@@ -73,13 +78,13 @@ namespace ngraph
private: private:
struct TensorInfo struct TensorInfo
{ {
mlir::Value* m_value; /* mlir value this tensor maps to */ // MLIR values this tensor maps to.
// More info here ? mlir::Value* m_value;
}; };
private: private:
void build_module(); void build_ng_dialect_module();
void lower_dialect(); void lower_ng_dialect();
void optimize(); void optimize();
void bind_arguments(); void bind_arguments();
void execute(); void execute();
...@@ -111,7 +116,19 @@ namespace ngraph ...@@ -111,7 +116,19 @@ namespace ngraph
mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type); mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type);
private: private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
// Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args;
// MLIR context that holds all the MLIR information related to the sub-graph
// compilation.
mlir::MLIRContext m_context; mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module; std::unique_ptr<mlir::Module> m_module;
std::unique_ptr<mlir::FuncBuilder> m_builder; std::unique_ptr<mlir::FuncBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine; std::unique_ptr<mlir::ExecutionEngine> m_engine;
...@@ -122,13 +139,6 @@ namespace ngraph ...@@ -122,13 +139,6 @@ namespace ngraph
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>; std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>; using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
llvm::SmallVector<void*, 8> m_invoke_args;
// Maps tensor to the value it represents in the IR // Maps tensor to the value it represents in the IR
// use for MLIR dialect gen // use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map; TensorToInfoMap m_tensor_to_value_map;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment