Commit 7c68d855 authored by Diego Caballero, committed by nmostafa

[MLIR] Add documentation and clean up compiler.* files (#40)

parent 5b9000ac
......@@ -85,15 +85,16 @@ void MLIRCompiler::init_mlir()
// Drives the full MLIR pipeline for one nGraph sub-graph: build the nGraph
// dialect module, lower it, optimize, bind the caller-supplied tensors, run
// the code via the execution engine, and release per-run resources.
// NOTE(review): reconstructed from a diff view — stale pre-commit calls
// (build_module/lower_dialect) removed in favor of the renamed post-commit
// entry points.
void MLIRCompiler::compile_and_run()
{
    build_ng_dialect_module(); // MLIR gen
    lower_ng_dialect();
    optimize();
    bind_arguments();
    execute();
    cleanup();
}
void MLIRCompiler::build_module()
// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
{
// initialize an empty module
m_module = make_unique<mlir::Module>(&m_context);
......@@ -146,6 +147,8 @@ void MLIRCompiler::build_module()
}
}
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> shape;
......@@ -157,6 +160,7 @@ mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
}
// Converts an nGraph element type into an MLIR type.
// Converts an nGraph element type into the corresponding MLIR nGraph-dialect
// scalar type. Unsupported types (undefined/dynamic) are a hard failure.
// Note: u8 and boolean intentionally share the UInt8 representation.
mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{
    switch (type.get_type_enum())
    {
    case ngraph::element::Type_t::undefined:
    case ngraph::element::Type_t::dynamic:
    default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
    case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
    case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
    case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context);
    case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context);
    case ngraph::element::Type_t::u8:
    case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context);
    case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context);
    // BUGFIX: u16 previously returned the signed Int16 type; it must map to
    // the unsigned variant, mirroring every other u*/i* pair in this switch.
    case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getUInt16(&m_context);
    case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context);
    case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context);
    case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
    case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
    }
    NGRAPH_FAIL() << "Unreachable";
    return mlir::Type(); // keeps compilers happy; never actually reached
}
......@@ -209,7 +202,8 @@ MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tens
return it->second;
}
void MLIRCompiler::lower_dialect()
// Lowers nGraph dialect to affine dialect.
void MLIRCompiler::lower_ng_dialect()
{
mlir::PassManager pm;
pm.addPass(mlir::createDialectLoweringPass(this));
......@@ -227,14 +221,16 @@ void MLIRCompiler::lower_dialect()
}
}
// Receives affine dialect as input and applies affine and standard dialect based optimizations.
// Lowering from affine dialect to standard dialect happens along the way. Output consists of
// standard dialect only ops.
// Receives affine dialect as input and applies affine and standard dialect based optimizations.
// Lowering from affine dialect to standard dialect happens along the way. Output consists of
// standard dialect only ops.
// NOTE(review): reconstructed from a diff view — the duplicated pre-commit
// assert/(void)rr lines are collapsed into the single post-commit NGRAPH_ASSERT.
void MLIRCompiler::optimize()
{
    mlir::PassManager pm;
    // Lower affine ops
    pm.addPass(mlir::createLowerAffinePass());
    auto rr = pm.run(m_module.get());
    NGRAPH_ASSERT(succeeded(rr)) << "Affine loop lowering failed";
}
// MLIR builders
......@@ -271,7 +267,6 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
// Emits an NGDotOp for an nGraph Dot node. Dot is strictly binary, so the
// operand count is validated before the generic binary-op builder is invoked.
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
NGRAPH_ASSERT(ng_node->get_arguments().size() == 2) << "Expected two operands in Dot operation";
return compiler.create_binary_op<mlir::NGDotOp>(ng_node);
}
......@@ -302,6 +297,8 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
{
NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
......@@ -338,6 +335,7 @@ void MLIRCompiler::bind_arguments()
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCompiler::execute()
{
NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
......
......@@ -47,9 +47,13 @@ namespace ngraph
{
namespace ngmlir
{
/// This class is the entry point to MLIR from nGraph. It drives the conversion of
/// nGraph sub-graphs, represented with CompiledKernel nodes, to MLIR nGraph dialect
/// and its lowering, optimization and execution using LLVM-based MLIR execution engine.
class MLIRCompiler
{
public:
/// Initializes MLIR environment. It must be called only once per execution.
static void init_mlir();
public:
......@@ -59,12 +63,13 @@ namespace ngraph
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
const std::vector<void*>& external_tensors);
/// Compiles and runs a subgraph in MLIR
/// Compiles and runs a subgraph in MLIR.
void compile_and_run();
/// Returns the memory manager used by this sub-graph compiler
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
/// Returns memory manager pointer argument ID in call interface
/// Returns memory manager pointer argument ID in call interface.
unsigned get_mem_mgr_arg_id(mlir::Function* func)
{
return func->getNumArguments() - 1;
......@@ -73,13 +78,13 @@ namespace ngraph
private:
struct TensorInfo
{
mlir::Value* m_value; /* mlir value this tensor maps to */
// More info here ?
// MLIR values this tensor maps to.
mlir::Value* m_value;
};
private:
void build_module();
void lower_dialect();
void build_ng_dialect_module();
void lower_ng_dialect();
void optimize();
void bind_arguments();
void execute();
......@@ -111,7 +116,19 @@ namespace ngraph
mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type);
private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
// Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args;
// MLIR context that holds all the MLIR information related to the sub-graph
// compilation.
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
std::unique_ptr<mlir::FuncBuilder> m_builder;
std::unique_ptr<mlir::ExecutionEngine> m_engine;
......@@ -122,13 +139,6 @@ namespace ngraph
std::function<mlir::Value*(MLIRCompiler& compiler, const ngraph::Node*)>;
using MLIRCompOpMap = std::unordered_map<std::type_index, MLIRCompOpFunction>;
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
const std::vector<void*>& m_external_tensors;
llvm::SmallVector<void*, 8> m_invoke_args;
// Maps tensor to the value it represents in the IR
// use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment