Commit 86bc31cc authored by Nagy Mostafa, committed by nmostafa

[MLIR] Move mlir related classes to MLIR namespace (#23)

* Move dialect and types to mlir namespace

* PR fixes and some cleanup

* Merge fix
parent ea441a6e
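
The change is mechanical but wide: NGDialect, the NG* element and tensor types, and the generated op classes move out of ngraph::runtime::ngmlir into the mlir namespace, so references throughout the compiler, the dialect lowerer, and the TableGen predicates pick up an mlir:: qualifier. A minimal before/after sketch of a call site (not part of the commit; it uses a hypothetical register_and_build helper and assumes the NG dialect/type headers from this tree are already included):

// Sketch only; assumes the NG dialect and type headers from this tree are included.
void register_and_build(mlir::MLIRContext* context, mlir::Type elt_type,
                        llvm::ArrayRef<int64_t> shape)
{
    // Before this commit the same calls were spelled through ngraph::runtime::ngmlir:
    //     mlir::registerDialect<ngraph::runtime::ngmlir::NGDialect>();
    //     ngraph::runtime::ngmlir::NGTensorType::get(context, elt_type, shape);
    mlir::registerDialect<mlir::NGDialect>();                           // dialect now lives in mlir::
    auto tensor_ty = mlir::NGTensorType::get(context, elt_type, shape); // and so do its types
    (void)tensor_ty; // unused in this sketch
}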
@@ -51,357 +51,353 @@ using namespace ngraph::runtime::ngmlir;
#define COMPILE_OP_DECL(op_name)                                                                   \
    create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)

MLIRCompiler::MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel,
                           const std::vector<void*>& external_tensors)
    : m_compiled_kernel(compiled_kernel)
    , m_external_tensors(external_tensors)
{
    NGRAPH_ASSERT((m_compiled_kernel->get_arguments().size() +
                   m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size())
        << "Number of arguments and outputs doesn't match number of tensors";
}

void MLIRCompiler::init_mlir()
{
    mlir::registerDialect<mlir::NGDialect>();
    // Register any LLVM command line options
    llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
}

void MLIRCompiler::compile_and_run()
{
    build_module(); // MLIR gen
    lower_dialect();
    optimize();
    bind_arguments();
    execute();
    cleanup();
}

void MLIRCompiler::build_module()
{
    // initialize an empty module
    m_module = make_unique<mlir::Module>(&m_context);
    TypeList args_type_list, result_type_list;

    // Retrieve input and output tensors.
    const auto& kernel_inputs = m_compiled_kernel->get_arguments();
    const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
    NGRAPH_ASSERT(kernel_inputs.size() != 0) << "Cannot have empty inputs list";
    NGRAPH_ASSERT(kernel_outputs.size() != 0) << "Cannot have empty outputs list";

    for (auto input : kernel_inputs)
    {
        args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
    }

    for (auto output : kernel_outputs)
    {
        result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
    }

    auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
    auto function =
        make_unique<mlir::Function>(mlir::UnknownLoc::get(&m_context), "main", func_type);
    function->addEntryBlock();

    // populate Tensor->Value maps
    int i = 0;
    for (auto input : kernel_inputs)
    {
        mlir::Value* arg = function->getArgument(i);
        TensorInfo tensor_info{arg};
        m_tensor_to_value_map.insert(
            TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
        i++;
    }

    // create builder
    m_builder = llvm::make_unique<mlir::FuncBuilder>(function.get());
    build_ng_dialect();
    m_module->getFunctions().push_back(function.release());
    if (failed(m_module->verify()))
    {
        NGRAPH_FAIL() << "Invalid module after lowering to NG dialect";
    }

    if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
    {
        m_module->dump();
    }
}

mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
    SmallVector<int64_t, 4> shape;
    for (auto d : tensor->get_shape())
    {
        shape.push_back(d);
    }

    return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
}

mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
{
    switch (type.get_type_enum())
    {
    case ngraph::element::Type_t::undefined:
    case ngraph::element::Type_t::dynamic:
    default: NGRAPH_FAIL() << "MLIR: Unsupported NGraph types"; break;
    case ngraph::element::Type_t::bf16: return mlir::NGFloatType::getBF16(&m_context);
    case ngraph::element::Type_t::f32: return mlir::NGFloatType::getF32(&m_context);
    case ngraph::element::Type_t::f64: return mlir::NGFloatType::getF64(&m_context);
    case ngraph::element::Type_t::i8: return mlir::NGIntegerType::getInt8(&m_context);
    case ngraph::element::Type_t::u8:
    case ngraph::element::Type_t::boolean: return mlir::NGIntegerType::getUInt8(&m_context);
    case ngraph::element::Type_t::i16: return mlir::NGIntegerType::getInt16(&m_context);
    case ngraph::element::Type_t::u16: return mlir::NGIntegerType::getInt16(&m_context);
    case ngraph::element::Type_t::i32: return mlir::NGIntegerType::getInt32(&m_context);
    case ngraph::element::Type_t::u32: return mlir::NGIntegerType::getUInt32(&m_context);
    case ngraph::element::Type_t::i64: return mlir::NGIntegerType::getInt64(&m_context);
    case ngraph::element::Type_t::u64: return mlir::NGIntegerType::getUInt64(&m_context);
    }
    NGRAPH_FAIL(); // Unreachable
    return mlir::Type();
}

void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{
    NGRAPH_ASSERT(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end())
        << "tensor value already defined";
    TensorInfo tensor_info{value};
    m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
}

MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tensor)
{
    auto it = m_tensor_to_value_map.find(tensor);
    NGRAPH_ASSERT(it != m_tensor_to_value_map.end()) << "Undefined tensor";
    return it->second;
}

void MLIRCompiler::lower_dialect()
{
    mlir::PassManager pm;
    pm.addPass(mlir::createDialectLoweringPass(this));
    pm.addPass(mlir::createCanonicalizerPass());
    pm.run(m_module.get());
    if (failed(m_module->verify()))
    {
        NGRAPH_FAIL() << "Incorrect module after dialect lowering";
    }

    if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
    {
        m_module->dump();
    }
}

void MLIRCompiler::optimize()
{
    mlir::PassManager pm;
    // Lower affine ops
    pm.addPass(mlir::createLowerAffinePass());
    auto rr = pm.run(m_module.get());
    (void)rr;
    assert(succeeded(rr) && "affine loop lowering failed");
}

// MLIR builders
#define TI(x) std::type_index(typeid(x))

void MLIRCompiler::build_ng_dialect()
{
    const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
    NGRAPH_ASSERT(sub_graph.size() == 1) << "Supporting code-gen for a single node for now";

    auto np = sub_graph[0];

    auto it = op_dispatcher.find(TI(*np));
    if (it == op_dispatcher.end())
    {
        throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
                             np->description() + "' operation"};
    }
    mlir::Value* mlir_value = it->second(*this, np.get());
    // builders that have multiple result values will update the value map, and set their ret values to null
    if (mlir_value)
    {
        update_tensor_value(np->get_output_tensor_ptr().get(), mlir_value);
    }
    create_return();
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{
    return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}

template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::MatmulBias)
{
    // TODO(dcab): Implement all the variants of a Matmul/MatmulBias op.
    // Keeping it simple for now.
    NGRAPH_ASSERT(ng_node->get_arguments().size() == 2)
        << "Bias is not supported in MatmulBias operation";

    return compiler.create_binary_op<mlir::NGMatMulBiasOp>(ng_node);
}

const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
    {TI(ngraph::op::Add), &MLIRCompiler::create_op<ngraph::op::Add>},
    {TI(ngraph::op::MatmulBias), &MLIRCompiler::create_op<ngraph::op::MatmulBias>}};

template <typename BinOp>
mlir::Value* MLIRCompiler::create_binary_op(const ngraph::Node* ng_node)
{
    auto lhs = ng_node->get_argument(0)->get_output_tensor_ptr();
    auto rhs = ng_node->get_argument(1)->get_output_tensor_ptr();
    auto lhs_v = get_tensor_value(lhs.get()).m_value;
    auto rhs_v = get_tensor_value(rhs.get()).m_value;
    return m_builder->create<BinOp>(mlir::UnknownLoc::get(&m_context), lhs_v, rhs_v).getResult();
}

void MLIRCompiler::create_return()
{
    std::vector<mlir::Value*> value_list;
    for (auto output : m_compiled_kernel->get_kernel_outputs())
    {
        value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
    }
    m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}

void MLIRCompiler::bind_arguments()
{
    NGRAPH_ASSERT(m_module && "MLIR module is not ready.");

    mlir::Function* func = m_module->getNamedFunction("main");
    NGRAPH_ASSERT(func && !func->getBlocks().empty()) << "Function not found";

    // Create list with a type-erased double pointer for each invocation arguments.
    // We currently use 'allocateMemRefArguments', which creates a
    // SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
    // actual pointer to the data.

    // create MemRef args
    auto expected_arguments = allocate_memref_args(func);
    NGRAPH_ASSERT(expected_arguments.size()) << "Arguments can't be created";
    m_invoke_args = std::move(expected_arguments);

    NGRAPH_ASSERT(m_invoke_args.size() == m_external_tensors.size())
        << "Number of external tensors doesn't match number of function arguments";

    // Assign external tensor pointers to invocation arguments.
    for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
    {
        ((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)m_external_tensors[i];
    }

    // Add pointer to memory manager
    // malloc here since that's what allocateMemRefArguments use
    // TODO (nmostafa): Better way of doing this ? Use builder allocator ?
    MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
    *mem_mgr_arg = &get_mem_mgr();
    // inserting memory manager ptr in right location ?
    NGRAPH_ASSERT(m_invoke_args.size() == get_mem_mgr_arg_id(func));
    m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}

void MLIRCompiler::execute()
{
    NGRAPH_ASSERT(m_module && "MLIR module is not ready.");

    // Lower Standard dialect to LLVM dialect.
    auto converter = mlir::createStdToLLVMConverter();
    auto r = converter->convert(m_module.get());
    (void)r;
    NGRAPH_ASSERT(succeeded(r)) << "second conversion failed";

    // Initialize LLVM targets.
    llvm::InitializeNativeTarget();
    llvm::InitializeNativeTargetAsmPrinter();

    // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
    // don't run MLIR passes that were already run. We also pass a default transformer to run
    // LLVM optimizations at level 3.
    mlir::PassManager* mlir_pm = nullptr;
    auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
    auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), mlir_pm, llvm_transformer);
    NGRAPH_ASSERT(maybeEngine) << "failed to construct an execution engine";
    m_engine = std::move(maybeEngine.get());

    // Invoke the JIT-compiled function with the arguments. Note that, for API
    // uniformity reasons, it takes a list of type-erased pointers to arguments.
    // Please, note that 'invoke' method is overloaded with a parameter pack version.
    // Make sure the MutableArrayRef version is invoked.
    auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
    NGRAPH_ASSERT(!invocationResult) << "JIT invocation of 'main' failed\n";
}

void MLIRCompiler::cleanup()
{
    // Free void double pointer arguments without freeing external tensor data.
    for (auto* arg : m_invoke_args)
    {
        free(arg);
    }

    // Free MLIR function builder.
    if (m_builder)
        m_builder.reset(nullptr);

    // Free allocated memory for JIT'ed code temps
    m_mem_mgr.freeAll();
}

SmallVector<void*, 8> MLIRCompiler::allocate_memref_args(mlir::Function* func)
{
    SmallVector<void*, 8> args;
    args.reserve(func->getNumArguments());
    for (const auto& arg : func->getArguments())
    {
        auto descriptor = allocate_memref_descriptor(arg->getType());

        if (!descriptor)
            continue;
        args.push_back(descriptor);
    }
    return args;
}

mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type type)
{
    auto memRefType = type.dyn_cast<mlir::MemRefType>();
    if (!memRefType)
        return nullptr;
    NGRAPH_ASSERT(memRefType.getNumDynamicDims() == 0) << "No support for dynamic shapes";

    // We only use StaticFloatMemRef because that's what MLIR currently offers.
    // We should expand this with different types and dynamic MemRefs
    auto* descriptor =
        reinterpret_cast<mlir::StaticFloatMemRef*>(malloc(sizeof(mlir::StaticFloatMemRef)));
    descriptor->data = nullptr;
    return descriptor;
}
@@ -18,11 +18,8 @@
#include "ops.hpp"
#include "type.hpp"

using namespace mlir;

NGDialect::NGDialect(mlir::MLIRContext* ctx)
    : mlir::Dialect("ng", ctx)
{
...
@@ -23,24 +23,17 @@
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
#include "ngraph/assertion.hpp"

namespace mlir
{
    class NGDialect : public mlir::Dialect
    {
    public:
        explicit NGDialect(mlir::MLIRContext* ctx);
        mlir::Type parseType(llvm::StringRef tyData, mlir::Location loc) const override
        {
            NGRAPH_ASSERT(0) << "Unsupported type parsing.";
            return mlir::Type();
        }
        void printType(mlir::Type type, llvm::raw_ostream& os) const override;
    };
}
@@ -26,83 +26,69 @@ using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using namespace mlir;

// TODO:
// - Move verifiers and other OP helpers (e.g. getSomeAttribute()) to separate files
//
// - Op helpers: Since it is not possible to add arbitrary code (and would complicate the .td file)
//   to Ops classes, we will add helper classes with static methods for each Op that needs it

// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
template <typename T>
static mlir::LogicalResult verifyUnaryArithOp(T* op)
{
    // TODO: Check matching element types
    return mlir::success();
}

// Additional verification methods
// Tensor type checks are already verified by the caller of these methods
template <typename T>
static mlir::LogicalResult verifyBinaryArithOp(T* op)
{
    // TODO: Check matching element types
    return mlir::success();
}

template <typename T>
static mlir::LogicalResult verifyOp(T* op)
{
    return op->emitOpError("Unsupported verifier for this operation");
}

// Per op specializations
template <>
mlir::LogicalResult verifyOp<NGMatMulBiasOp>(NGMatMulBiasOp* op)
{
    // Verify that we have 2 operands
    // Bias operand must be null for now (not implemented)
    if (op->getNumOperands() != 2)
    {
        std::stringstream ss;
        ss << "Unexpected MatmulBiasOp with " << op->getNumOperands()
           << " operands. 3 operands expected";
        return op->emitOpError(ss.str());
    }

    // Verify that operand types are supported.
    auto op0_tensor_ty = op->getOperand(0)->getType().cast<NGTensorType>();
    auto op1_tensor_ty = op->getOperand(1)->getType().cast<NGTensorType>();

    // Verify that operand shapes are supported.
    if (op0_tensor_ty.getRank() != 2 || op1_tensor_ty.getRank() != 2)
    {
        return op->emitOpError(
            "Unsupported number of dimensions. Only 2D tensors are supported in "
            "MatmulBiasOp");
    }

    // TODO(dcab): Improve verification: matching types, proper shapes, etc.

    return mlir::success();
}

namespace mlir
{
#define GET_OP_CLASSES
#include "ops.cpp.inc"
}
@@ -22,19 +22,8 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"

namespace mlir
{
#define GET_OP_CLASSES
#include "ops.h.inc"
}
@@ -40,7 +40,7 @@ include "mlir/IR/OpBase.td"
// This defines records equivalent to NGraph types. It doesn't generate code.
// This is used as a type in the DAG input/outputs.
// Constraints (CPred) are used to type-check args/results of that type during op verification
def NG_TensorType : Type<CPred<"{0}.isa<mlir::NGTensorType>()">,
                         "NGraph Tensor Type">;

// A generic un-typed MemRef. Used for Fake instructions inserted during dialect lowering
...
@@ -31,55 +31,51 @@ using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using namespace mlir;

unsigned NGIntegerType::getWidth() const
{
    switch (getKind())
    {
    case NG_I8_TYPE_ID:
    case NG_U8_TYPE_ID: return 8;
    case NG_I16_TYPE_ID:
    case NG_U16_TYPE_ID: return 16;
    case NG_I32_TYPE_ID:
    case NG_U32_TYPE_ID: return 32;
    case NG_I64_TYPE_ID:
    case NG_U64_TYPE_ID: return 64;
    default: NGRAPH_FAIL() << "Invalid type ID";
    }
    return 0;
}

bool NGIntegerType::isSigned() const
{
    switch (getKind())
    {
    case NG_I8_TYPE_ID:
    case NG_I16_TYPE_ID:
    case NG_I32_TYPE_ID:
    case NG_I64_TYPE_ID: return true;
    case NG_U8_TYPE_ID:
    case NG_U16_TYPE_ID:
    case NG_U32_TYPE_ID:
    case NG_U64_TYPE_ID: return false;
    default: NGRAPH_FAIL() << "Invalid type ID";
    }
    return false;
}

/// Creates TensorType objects. They all point to the same storage if
/// element type and shape are the same.
NGTensorType NGTensorType::get(MLIRContext* context, EltType eltType, Shape shape)
{
    return Base::get(context, NGTypeKind::NG_TENSOR_TYPE_ID, eltType, shape);
}

MemRefType NGTensorType::toMemref()
{
    auto memRefType = MemRefType::get(getShape(), getElementType(), {/* no map used */}, 0);
    return memRefType;
}
@@ -23,238 +23,228 @@
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"

namespace mlir
{
    using llvm::raw_ostream;

    enum NGTypeKind
    {
        // The enum starts at the range reserved for this dialect.
        // These values are pre-defined in MLIR lib and not configurable from here.
        NG_TYPE = mlir::Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
        // Element types that are added by the dialect.
        // Other types are just re-use of std dialect types.
        NG_FIRST_INT_TYPE_ID,
        NG_I8_TYPE_ID = NG_FIRST_INT_TYPE_ID,
        NG_I16_TYPE_ID,
        NG_I32_TYPE_ID,
        NG_I64_TYPE_ID,
        NG_U8_TYPE_ID,
        NG_U16_TYPE_ID,
        NG_U32_TYPE_ID,
        NG_U64_TYPE_ID,
        NG_LAST_INT_TYPE_ID = NG_U64_TYPE_ID,
        NG_BOOL_TYPE_ID,
        // Tensor type
        NG_TENSOR_TYPE_ID
    };

    // reuse std float types as-is
    using NGFloatType = mlir::FloatType;

    /// Integer type. It represents an integer of width 8,16,32,64. Signed or not.
    class NGIntegerType : public mlir::Type::TypeBase<NGIntegerType, mlir::Type>
    {
    public:
        using Base::Base;

        static NGIntegerType get(NGTypeKind kind, mlir::MLIRContext* context)
        {
            NGRAPH_ASSERT(kindof(kind)) << "Not an integer kind.";
            return Base::get(context, kind);
        }
        /// Create signed Int8
        static NGIntegerType getInt8(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_I8_TYPE_ID, ctx);
        }
        /// Create signed Int16
        static NGIntegerType getInt16(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_I16_TYPE_ID, ctx);
        }
        /// Create signed Int32
        static NGIntegerType getInt32(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_I32_TYPE_ID, ctx);
        }
        /// Create signed Int64
        static NGIntegerType getInt64(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_I64_TYPE_ID, ctx);
        }
        /// Create unsigned Int8
        static NGIntegerType getUInt8(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_U8_TYPE_ID, ctx);
        }
        /// Create unsigned Int16
        static NGIntegerType getUInt16(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_U16_TYPE_ID, ctx);
        }
        /// Create unsigned Int32
        static NGIntegerType getUInt32(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_U32_TYPE_ID, ctx);
        }
        /// Create unsigned Int64
        static NGIntegerType getUInt64(mlir::MLIRContext* ctx)
        {
            return get(NGTypeKind::NG_U64_TYPE_ID, ctx);
        }

        /// RTTI support. So we can do obj->isa<NGIntegerType>()
        static bool kindof(unsigned kind)
        {
            return kind >= NGTypeKind::NG_FIRST_INT_TYPE_ID &&
                   kind <= NGTypeKind::NG_LAST_INT_TYPE_ID;
        }

        /// Return the bitwidth of this integer type.
        unsigned getWidth() const;

        /// Convert to equivalent std type
        /// std types are sign-agnostic.
        mlir::Type toStdType() const { return mlir::IntegerType::get(getWidth(), getContext()); }
        /// Check if signed type
        bool isSigned() const;
        /// Check if Int8
        bool isInt8() const { return getKind() == NG_I8_TYPE_ID; }
        /// Check if UInt8
        bool isUInt8() const { return getKind() == NG_U8_TYPE_ID; }
        /// Check if Int16
        bool isInt16() const { return getKind() == NG_I16_TYPE_ID; }
        /// Check if UInt16
        bool isUInt16() const { return getKind() == NG_U16_TYPE_ID; }
        /// Check if Int32
        bool isInt32() const { return getKind() == NG_I32_TYPE_ID; }
        /// Check if UInt32
        bool isUInt32() const { return getKind() == NG_U32_TYPE_ID; }
        /// Check if Int64
        bool isInt64() const { return getKind() == NG_I64_TYPE_ID; }
        /// Check if UInt64
        bool isUInt64() const { return getKind() == NG_U64_TYPE_ID; }
        // Delete convenience methods inherited from MLIR Type class.
        // This would avoid confusion if we do something like this and get false.
        //
        // if (type->cast<NGIntegerType>()->isInteger(32)) {}
        //
        // Those helpers use type id, and since we have our own Integer type id, they
        // don't apply.
        bool isInteger(unsigned width) const = delete;
        unsigned getIntOrFloatBitWidth() const = delete;
        bool isIntOrIndex() const = delete;
        bool isIntOrIndexOrFloat() const = delete;
        bool isIntOrFloat() const = delete;
    };

    /// Boolean Type.
    class NGBoolType : public mlir::Type::TypeBase<NGBoolType, mlir::Type>
    {
    public:
        using Base::Base;
        static NGBoolType get(NGTypeKind kind, mlir::MLIRContext* context)
        {
            NGRAPH_ASSERT(kindof(kind)) << "Not a bool type.";
            return Base::get(context, kind);
        }

        static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
        static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
        /// Convert to equivalent std type. Integer of width 1 in that case
        mlir::Type toStdType() const { return mlir::IntegerType::get(1, getContext()); }
    };

    // Note that dialect types don't add new data members, so always possible
    // to use NG or std types here
    using EltType = mlir::Type;
    // TODO: Can we use ngraph::shape here (given the hashing requirements)
    using Shape = llvm::ArrayRef<int64_t>;

    /// Tensor Type storage. There is a unique instance per type attributes.
    /// Tensor Type is combination of the element type and shape. Each different
    /// shape is a unique type.
    struct NGTensorTypeStorage : public mlir::TypeStorage
    {
        // Tensor key is its type and shape.
        // This is called when the user requests a specific tensor type
        using KeyTy = std::tuple<EltType, Shape>;

        static unsigned hashKey(const KeyTy& key)
        {
            return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
        }

        bool operator==(const KeyTy& key) const
        {
            return key == KeyTy(getElementType(), getShape());
        }

        static NGTensorTypeStorage* construct(mlir::TypeStorageAllocator& allocator,
                                              const KeyTy& key)
        {
            // Deep copy the type shape over to MLIR context
            EltType eltType = std::get<0>(key);
            Shape shape = allocator.copyInto(std::get<1>(key));
            auto* storage = allocator.allocate<NGTensorTypeStorage>();
            return new (storage) NGTensorTypeStorage(eltType, shape);
        }

        Shape getShape() const { return m_shape; }
        EltType getElementType() const { return m_eltType; }
    private:
        NGTensorTypeStorage(EltType eltType, Shape shape)
            : m_eltType(eltType)
            , m_shape(shape)
        {
        }

    private:
        EltType m_eltType;
        Shape m_shape;
    };

    /// NGraph Tensor Type
    class NGTensorType : public mlir::Type::TypeBase<NGTensorType, mlir::Type, NGTensorTypeStorage>
    {
    public:
        using Base::Base;
        EltType getElementType() const { return getImpl()->getElementType(); }
        Shape getShape() const { return getImpl()->getShape(); }
        /// Tensor Rank. Static shape only for now
        int getRank() { return getShape().size(); }
        /// Computes tensor size in bytes
        size_t getSizeInBytes()
        {
            size_t s = 1;
            auto shape = getShape();
            for (auto i = 0; i < getRank(); i++)
            {
                // no dynamic dims
                if (shape[i] == -1)
                    return -1;
                s *= shape[i];
            }
            // Multiply times element size
            return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
        }
        /// convert to memref native MLIR type. Used for lowering.
        mlir::MemRefType toMemref();
        /// create a unique tensor type based on element type and shape.
        static NGTensorType get(mlir::MLIRContext* context, EltType eltType, Shape shape);
        /// for llvm RTTI
        static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_TENSOR_TYPE_ID; }
    };
}
@@ -131,7 +131,7 @@ namespace
        // we find out output values by looking at returned values
        // any return should return all outputs of the subgraph
        f->walk<NGReturnOp>([this, &outputCount](NGReturnOp ret) {
            for (unsigned i = 0; i < ret.getNumOperands(); i++)
            {
                this->m_outputValueMap.insert(std::pair<Value*, unsigned>(ret.getOperand(i), i));
@@ -151,8 +151,8 @@
        // however, due to how DialectConversion framework works, new func is only
        // materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
        // will give you the original func). This makes it very convoluted to insert instructions at entry block.
        auto op = rewriter->create<NGFakeInputOp>(rewriter->getUnknownLoc(),
                                                  IndexType::get(getModule().getContext()));
        // will be fixed later to read passed arg instead.
        m_memMgrDefs.push_back(op.getResult());
        return op.getResult();
@@ -170,7 +170,7 @@
        if (it != outputMap.end())
        {
            unsigned argId = (*it).second;
            auto fakeOp = rewriter.create<NGFakeInputOp>(
                op->getLoc(),
                m_dialectLowerer.convertType(
                    origResult->getType()) /* convert to lowered type */
@@ -183,7 +183,7 @@
        }
        else
        {
            auto tensorType = origResult->getType().cast<NGTensorType>();
            auto callBackFunc = getCallDecl("__mlir_allocate",
                                            {rewriter.getIndexType(), rewriter.getIndexType()},
                                            {tensorType.toMemref()},
@@ -237,8 +237,7 @@
        for (auto value : m_loweredOutputValues)
        {
            auto op = value->getDefiningOp();
            NGRAPH_ASSERT(op->isa<NGFakeInputOp>()) << "output value not defined by fake output?";
            value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
            op->erase();
            i++;
@@ -269,23 +268,23 @@
    // NGDialect converters
    Type DialectLowerer::convertType(Type t)
    {
        if (auto tensor = t.dyn_cast<NGTensorType>())
        {
            return tensor.toMemref();
        }
        // element type
        if (auto type = t.dyn_cast<NGFloatType>())
        {
            // Float
            // float types are already std type
            return type;
        }
        if (auto type = t.dyn_cast<NGIntegerType>())
        {
            // map it to std type
            return type.toStdType();
        }
        if (auto type = t.dyn_cast<NGBoolType>())
        {
            return type.toStdType();
        }
@@ -298,7 +297,7 @@
                                   ArrayRef<Value*> operands,
                                   FuncBuilder& rewriter) const
    {
        auto add = op->cast<NGAddOp>();
        auto loc = add.getLoc();
        Value *origResult, *newResult;
@@ -335,7 +334,7 @@
                                   ArrayRef<Value*> operands,
                                   FuncBuilder& rewriter) const
    {
        auto matmul = op->cast<NGMatMulBiasOp>();
        auto loc = matmul.getLoc();
        NGRAPH_ASSERT(operands.size() == 2) << "Bias is not supported yet in MatmulBias operation";
@@ -406,16 +405,10 @@
    }
}

namespace mlir
{
    Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler)
    {
        return new DialectLoweringPass(*compiler);
    }
}
@@ -16,9 +16,9 @@
#pragma once

#include "contrib/mlir/compiler.hpp"

#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"

namespace ngraph
{
    namespace runtime
@@ -26,8 +26,10 @@ namespace ngraph
        namespace ngmlir
        {
            class MLIRCompiler;
        }
    }
}

namespace mlir
{
    mlir::Pass* createDialectLoweringPass(ngraph::runtime::ngmlir::MLIRCompiler* compiler);
}
@@ -22,7 +22,7 @@ class OP##Conversion : public mlir::DialectOpConversion \
{\
public:\
    explicit OP##Conversion(mlir::MLIRContext *context, DialectLoweringPass& pass)\
        : mlir::DialectOpConversion(mlir::OP::getOperationName(), 1, context),\
          m_pass(pass)\
    {} \
    SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands, FuncBuilder &rewriter) const override; \
...