Commit 7b3c323b authored by Nagy Mostafa's avatar Nagy Mostafa Committed by omarkanawi

[MLIR] Fix style in compiler and lowerer files (#3564)

* Fix style in compiler and lowerer files

* Fix comment in headers

* Revert "Fix comment in headers"

This reverts commit d52eed4c1bdf371f3cc7d3f601d9d2b1b0c233e8.

* Fix compiler.* header. Fix code style in other files
parent 76e4485b
......@@ -14,8 +14,8 @@
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#include "compiler.hpp"
......@@ -116,17 +116,17 @@ static llvm::cl::opt<unsigned> clLoopTilingCacheSize(
"-loop-tile-cache-level."));
#define COMPILE_OP_DECL(op_name) \
create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)
createOp<op_name>(MLIRCompiler & compiler, const ngraph::Node* ngNode)
// Default optimization level.
unsigned MLIRCompiler::mlir_opt_level = 2;
unsigned MLIRCompiler::mlirOptLevel = 2;
// Target machine will be properly initialized by `init_mlir`.
std::unique_ptr<llvm::TargetMachine> MLIRCompiler::target_machine;
std::unique_ptr<llvm::TargetMachine> MLIRCompiler::targetMachine;
/// Creates target machine for current host.
static llvm::Expected<std::unique_ptr<llvm::TargetMachine>>
createDefaultTargetMachine(unsigned opt_level)
createDefaultTargetMachine(unsigned optLevel)
{
auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
if (!machineBuilder)
......@@ -142,17 +142,17 @@ static llvm::Expected<std::unique_ptr<llvm::TargetMachine>>
// Default, // -O2, -Os
// Aggressive // -O3
// };
machineBuilder->setCodeGenOptLevel((llvm::CodeGenOpt::Level)opt_level);
machineBuilder->setCodeGenOptLevel((llvm::CodeGenOpt::Level)optLevel);
return machineBuilder->createTargetMachine();
}
void MLIRCompiler::init_mlir()
{
// Mutex to safely initialize MLIR.
static std::mutex mlir_init_mutex;
static std::mutex mlirInitMutex;
static bool initialized = false;
std::unique_lock<std::mutex> lock(mlir_init_mutex);
std::unique_lock<std::mutex> lock(mlirInitMutex);
if (!initialized)
{
......@@ -164,18 +164,18 @@ void MLIRCompiler::init_mlir()
llvm::cl::ParseEnvironmentOptions("ngraph", "NGRAPH_MLIR_OPTIONS", "");
// Override default optimization level with macro value.
if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
if (char* optLevelStr = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
{
mlir_opt_level = std::stoi(opt_level_str);
NGRAPH_CHECK(mlir_opt_level >= 0 && mlir_opt_level <= 3, "Invalid optimization level");
mlirOptLevel = std::stoi(optLevelStr);
NGRAPH_CHECK(mlirOptLevel >= 0 && mlirOptLevel <= 3, "Invalid optimization level");
}
// Initialize LLVM targets and target machine for current host.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
auto expected_target_machine = createDefaultTargetMachine(mlir_opt_level);
NGRAPH_CHECK(expected_target_machine, "Invalid target machine");
target_machine = std::move(*expected_target_machine);
auto expectedTargetMachine = createDefaultTargetMachine(mlirOptLevel);
NGRAPH_CHECK(expectedTargetMachine, "Invalid target machine");
targetMachine = std::move(*expectedTargetMachine);
initialized = true;
}
......@@ -183,97 +183,95 @@ void MLIRCompiler::init_mlir()
void MLIRCompiler::compile()
{
build_ng_dialect_module();
lower_ng_dialect();
buildNgDialectModule();
lowerNgDialect();
}
void MLIRCompiler::run(std::vector<void*>& external_tensors)
void MLIRCompiler::run(std::vector<void*>& externalTensors)
{
bind_arguments(external_tensors);
bindArguments(externalTensors);
execute();
cleanup();
}
// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
void MLIRCompiler::buildNgDialectModule()
{
// initialize an empty module
m_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&m_context));
TypeList args_type_list, result_type_list;
TypeList argsTypeList, resultTypeList;
// Retrieve input and output tensors.
const auto& kernel_inputs = m_compiled_kernel->get_arguments();
const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
NGRAPH_CHECK(kernel_inputs.size() != 0, "Cannot have empty inputs list");
NGRAPH_CHECK(kernel_outputs.size() != 0, "Cannot have empty outputs list");
const auto& kernelInputs = m_compiledKernel->get_arguments();
const auto& kernelOutput = m_compiledKernel->get_kernel_outputs();
NGRAPH_CHECK(kernelInputs.size() != 0, "Cannot have empty inputs list");
NGRAPH_CHECK(kernelOutput.size() != 0, "Cannot have empty outputs list");
for (auto input : kernel_inputs)
for (auto input : kernelInputs)
{
args_type_list.push_back(get_mlir_type(input.get()));
argsTypeList.push_back(getMlirType(input.get()));
}
for (auto output : kernel_outputs)
for (auto output : kernelOutput)
{
result_type_list.push_back(get_mlir_type(output.get()));
resultTypeList.push_back(getMlirType(output.get()));
}
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(&m_context), "main", func_type);
auto funcType = mlir::FunctionType::get(argsTypeList, resultTypeList, &m_context);
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(&m_context), "main", funcType);
function.addEntryBlock();
// populate Tensor->Value maps
int i = 0;
for (auto input : kernel_inputs)
for (auto input : kernelInputs)
{
mlir::Value* arg = function.getArgument(i);
TensorInfo tensor_info{arg};
m_tensor_to_value_map.insert(
TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
TensorInfo tensorInfo{arg};
m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo));
i++;
}
// create builder
m_builder = std::unique_ptr<mlir::OpBuilder>(new mlir::OpBuilder(function.getBody()));
build_ng_dialect();
buildNgDialect();
m_module->push_back(function);
if (failed(m_module->verify()))
{
NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
}
dump_mlir_module("nGraph Dialect Construction");
dumpMlirModule("nGraph Dialect Construction");
}
template <typename T>
void MLIRCompiler::get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
void MLIRCompiler::getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape)
{
for (auto dim : ng_shape)
for (auto dim : ngShape)
{
mlir_shape.push_back(dim);
mlirShape.push_back(dim);
}
}
template <typename T>
mlir::ArrayAttr MLIRCompiler::get_shape_as_attr(T ng_shape)
mlir::ArrayAttr MLIRCompiler::getShapeAsAttr(T ngShape)
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(ng_shape, mlir_shape);
return m_builder->getI64ArrayAttr(mlir_shape);
SmallVector<int64_t, 4> mlirShape;
getMlirShape(ngShape, mlirShape);
return m_builder->getI64ArrayAttr(mlirShape);
}
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
mlir::Type MLIRCompiler::getMlirType(const descriptor::Tensor* tensor)
{
llvm::SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(
&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
llvm::SmallVector<int64_t, 4> mlirShape;
getMlirShape(tensor->get_shape(), mlirShape);
return mlir::NGTensorType::get(&m_context, getMlirType(tensor->get_element_type()), mlirShape);
}
// Converts an nGraph element type into an MLIR type.
mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
mlir::Type MLIRCompiler::getMlirType(const element::Type& type)
{
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
......@@ -308,31 +306,31 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#endif
}
mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
mlir::Type MLIRCompiler::getMlirType(const ngraph::Node* node)
{
descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
return get_mlir_type(out_tensor);
descriptor::Tensor* outTensor = node->get_output_tensor_ptr().get();
return getMlirType(outTensor);
}
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
void MLIRCompiler::updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value)
{
NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
NGRAPH_CHECK(m_tensorToValueMap.find(tensor) == m_tensorToValueMap.end(),
"tensor value already defined");
TensorInfo tensor_info{value};
m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
TensorInfo tensorInfo{value};
m_tensorToValueMap.insert(TensorToInfo(tensor, tensorInfo));
}
MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tensor)
MLIRCompiler::TensorInfo MLIRCompiler::getTensorValue(descriptor::Tensor* tensor)
{
auto it = m_tensor_to_value_map.find(tensor);
auto it = m_tensorToValueMap.find(tensor);
NGRAPH_CHECK(it != m_tensor_to_value_map.end(), "Undefined tensor");
NGRAPH_CHECK(it != m_tensorToValueMap.end(), "Undefined tensor");
return it->second;
}
// Lowers nGraph dialect all the way to LLVM module.
void MLIRCompiler::lower_ng_dialect()
void MLIRCompiler::lowerNgDialect()
{
// Lower NG dialect to Affine
mlir::PassManager pm(&m_context);
......@@ -354,27 +352,27 @@ void MLIRCompiler::lower_ng_dialect()
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
// Lower Standard dialect to LLVM dialect.
mlir::LLVMTypeConverter llvm_converter(&m_context);
mlir::LLVMTypeConverter llvmConverter(&m_context);
mlir::OwningRewritePatternList patterns;
mlir::populateLoopToStdConversionPatterns(patterns, &m_context);
mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
mlir::populateStdToLLVMConversionPatterns(llvmConverter, patterns);
mlir::ConversionTarget target(m_context);
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
target.addDynamicallyLegalOp<mlir::FuncOp>(
[&](mlir::FuncOp op) { return llvm_converter.isSignatureLegal(op.getType()); });
auto result = applyFullConversion(*m_module, target, std::move(patterns), &llvm_converter);
[&](mlir::FuncOp op) { return llvmConverter.isSignatureLegal(op.getType()); });
auto result = applyFullConversion(*m_module, target, std::move(patterns), &llvmConverter);
NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
dump_mlir_module("LLVM-IR Dialect Conversion");
dumpMlirModule("LLVM-IR Dialect Conversion");
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer created with
// the default or user-provided optimization level.
auto llvm_transformer =
mlir::makeOptimizingTransformer(mlir_opt_level, /*sizeLevel=*/0, target_machine.get());
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
auto llvmTransformer =
mlir::makeOptimizingTransformer(mlirOptLevel, /*sizeLevel=*/0, targetMachine.get());
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvmTransformer);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
}
......@@ -417,13 +415,13 @@ void MLIRCompiler::optimize()
// LLVM TTI infra while MLIR does not have target model.
llvm::LLVMContext llvmContext;
auto module = std::unique_ptr<llvm::Module>(new llvm::Module("test", llvmContext));
module->setDataLayout(target_machine->createDataLayout());
module->setDataLayout(targetMachine->createDataLayout());
auto ttiSetupFunc = llvm::cast<llvm::Function>(
module
->getOrInsertFunction("__ngraph_tti_setup",
llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), {}))
.getCallee());
auto targetInfo = target_machine->getTargetTransformInfo(*ttiSetupFunc);
auto targetInfo = targetMachine->getTargetTransformInfo(*ttiSetupFunc);
// Populate pass manager with affine dialect optimizations.
mlir::PassManager pm(&m_context);
......@@ -461,14 +459,14 @@ void MLIRCompiler::optimize()
// MLIR builders
#define TI(x) std::type_index(typeid(x))
void MLIRCompiler::build_ng_dialect()
void MLIRCompiler::buildNgDialect()
{
const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
const NodeVector& subGraph = m_compiledKernel->get_node_list();
for (auto np : sub_graph)
for (auto np : subGraph)
{
auto it = op_dispatcher.find(TI(*np));
if (it == op_dispatcher.end())
auto it = opDispatcher.find(TI(*np));
if (it == opDispatcher.end())
{
throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
np->description() + "' operation"};
......@@ -484,12 +482,12 @@ void MLIRCompiler::build_ng_dialect()
mlir::Value* result = op->getResult(i);
if (result)
{
update_tensor_value(np->get_output_tensor_ptr(i).get(), result);
updateTensorValue(np->get_output_tensor_ptr(i).get(), result);
}
}
}
}
create_return();
createReturn();
}
namespace ngraph
......@@ -501,118 +499,117 @@ namespace ngraph
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
{
return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
return compiler.createGenericOp<mlir::NGAddOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
{
return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
return compiler.createGenericOp<mlir::NGSubOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
{
return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
return compiler.createGenericOp<mlir::NGMulOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
{
return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
return compiler.createGenericOp<mlir::NGDivOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
{
return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
return compiler.createGenericOp<mlir::NGGreaterOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
{
return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
return compiler.createGenericOp<mlir::NGLessOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
{
return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
return compiler.createGenericOp<mlir::NGMaxOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
{
return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
return compiler.createGenericOp<mlir::NGMinOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
return compiler.createIndexReduction<mlir::NGArgMaxRedOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
return compiler.createIndexReduction<mlir::NGArgMinRedOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
return compiler.createGenericOp<mlir::NGDotOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
{
auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
op->setAttr("concatenation_axis",
compiler.m_builder->getI64IntegerAttr(
ng_node_concat->get_concatenation_axis()));
auto concat = static_cast<const ngraph::op::Concat*>(ngNode);
auto op = compiler.createGenericOp<mlir::NGConcatOp>(ngNode);
op->setAttr(
"concatenation_axis",
compiler.m_builder->getI64IntegerAttr(concat->get_concatenation_axis()));
return op;
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
{
auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
op->setAttr("axis",
compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
auto gather = static_cast<const ngraph::op::Gather*>(ngNode);
auto op = compiler.createGenericOp<mlir::NGGatherOp>(ngNode);
op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(gather->get_axis()));
return op;
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
{
return compiler.create_generic_op<mlir::NGReluOp>(ng_node);
return compiler.createGenericOp<mlir::NGReluOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Negative)
{
return compiler.create_generic_op<mlir::NGNegOp>(ng_node);
return compiler.createGenericOp<mlir::NGNegOp>(ngNode);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Convolution)
{
mlir::Operation* op = compiler.create_generic_op<mlir::NGConvolutionOp>(ng_node);
auto conv_node = static_cast<const ngraph::op::Convolution*>(ng_node);
auto conv_op = llvm::cast<mlir::NGConvolutionOp>(op);
mlir::Operation* op = compiler.createGenericOp<mlir::NGConvolutionOp>(ngNode);
auto convNode = static_cast<const ngraph::op::Convolution*>(ngNode);
auto convOp = llvm::cast<mlir::NGConvolutionOp>(op);
mlir::ArrayAttr attr =
compiler.get_shape_as_attr(conv_node->get_window_movement_strides());
conv_op.setStrides(attr);
compiler.getShapeAsAttr(convNode->get_window_movement_strides());
convOp.setStrides(attr);
attr = compiler.get_shape_as_attr(conv_node->get_padding_below());
conv_op.setPadBelow(attr);
attr = compiler.getShapeAsAttr(convNode->get_padding_below());
convOp.setPadBelow(attr);
attr = compiler.get_shape_as_attr(conv_node->get_padding_above());
conv_op.setPadAbove(attr);
attr = compiler.getShapeAsAttr(convNode->get_padding_above());
convOp.setPadAbove(attr);
return op;
}
}
......@@ -620,58 +617,58 @@ namespace ngraph
}
template <typename Op>
mlir::Operation* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
mlir::Operation* MLIRCompiler::createGenericOp(const ngraph::Node* ngNode)
{
std::vector<mlir::Value*> arg_values;
std::vector<mlir::Type> res_types;
for (auto& arg : ng_node->get_arguments())
std::vector<mlir::Value*> argValues;
std::vector<mlir::Type> resTypes;
for (auto& arg : ngNode->get_arguments())
{
auto arg_tensor = arg->get_output_tensor_ptr();
auto arg_v = get_tensor_value(arg_tensor.get()).m_value;
arg_values.push_back(arg_v);
auto argTensor = arg->get_output_tensor_ptr();
auto argv = getTensorValue(argTensor.get()).m_value;
argValues.push_back(argv);
}
for (auto& output : ng_node->outputs())
for (auto& output : ngNode->outputs())
{
res_types.push_back(get_mlir_type(output.get_tensor_ptr().get()));
resTypes.push_back(getMlirType(output.get_tensor_ptr().get()));
}
return (m_builder->create<Op,
ArrayRef<mlir::Type>,
ArrayRef<mlir::Value*>,
ArrayRef<mlir::NamedAttribute>>(
mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */}))
mlir::UnknownLoc::get(&m_context), resTypes, argValues, {/* no attrs */}))
.getOperation();
}
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
const MLIRCompiler::MLIRCompOpMap MLIRCompiler::opDispatcher{
#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::createOp<ngraph::op::OP>},
#include "ops_supported.inc"
};
void MLIRCompiler::create_return()
void MLIRCompiler::createReturn()
{
std::vector<mlir::Value*> value_list;
for (auto output : m_compiled_kernel->get_kernel_outputs())
std::vector<mlir::Value*> valueList;
for (auto output : m_compiledKernel->get_kernel_outputs())
{
value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
valueList.push_back(getTensorValue(output->get_output_tensor_ptr().get()).m_value);
}
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), valueList);
}
template <typename RedOp>
mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
mlir::Operation* MLIRCompiler::createIndexReduction(const ngraph::Node* ngNode)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto op = create_generic_op<RedOp>(ng_node);
mlir::ArrayAttr red_axes_attr =
m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
op->setAttr("axes", red_axes_attr);
auto* idxRed = static_cast<const ngraph::op::util::IndexReduction*>(ngNode);
auto op = createGenericOp<RedOp>(ngNode);
mlir::ArrayAttr redAxesAttr =
m_builder->getI64ArrayAttr({(int64_t)idxRed->get_reduction_axis()});
op->setAttr("axes", redAxesAttr);
return op;
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
......@@ -679,29 +676,29 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments
NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
NGRAPH_CHECK(m_compiledKernel, "No compiled kernel set for compiler");
NGRAPH_CHECK((m_compiledKernel->get_arguments().size() +
m_compiledKernel->get_kernel_outputs().size()) == externalTensors.size(),
"Number of arguments and outputs doesn't match number of tensors");
m_external_tensors = &external_tensors;
m_externalTensors = &externalTensors;
// Create list with a type-erased double pointer for each invocation arguments.
// We currently use 'allocateMemRefArguments', which creates a
// We currently use 'allocateMemrefArgs', which creates a
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
// actual pointer to the data.
// create MemRef args
auto expected_arguments = allocate_memref_args();
NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
m_invoke_args = std::move(expected_arguments);
auto expectedArguments = allocateMemrefArgs();
NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
m_invokeArgs = std::move(expectedArguments);
NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors->size(),
NGRAPH_CHECK(m_invokeArgs.size() == m_externalTensors->size(),
"Number of external tensors doesn't match number of function arguments");
// Assign external tensor pointers to invocation arguments.
for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i];
((mlir::StaticFloatMemRef*)m_invokeArgs[i])->data = (float*)(*m_externalTensors)[i];
}
}
......@@ -712,14 +709,14 @@ void MLIRCompiler::execute()
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
}
void MLIRCompiler::cleanup()
{
// Free void double pointer arguments without freeing external tensor data.
for (auto* arg : m_invoke_args)
for (auto* arg : m_invokeArgs)
{
free(arg);
}
......@@ -731,18 +728,18 @@ void MLIRCompiler::cleanup()
}
}
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args()
SmallVector<void*, 8> MLIRCompiler::allocateMemrefArgs()
{
SmallVector<void*, 8> args;
for (auto i = 0; i < m_external_tensors->size(); i++)
for (auto i = 0; i < m_externalTensors->size(); i++)
{
auto descriptor = allocate_memref_descriptor();
auto descriptor = allocateMemrefDescriptor();
args.push_back(descriptor);
}
return args;
}
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor()
mlir::StaticFloatMemRef* MLIRCompiler::allocateMemrefDescriptor()
{
// We only use StaticFloatMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
......@@ -753,7 +750,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor()
return descriptor;
}
void MLIRCompiler::dump_mlir_module(const std::string msg)
void MLIRCompiler::dumpMlirModule(const std::string msg)
{
if (clPrintIRAfterAll)
{
......
......@@ -14,8 +14,8 @@
// limitations under the License.
//*****************************************************************************
// NOTE: This file follows nGraph format style and naming convention since it
// exposes a public API to the rest of nGraph codebase.
// NOTE: This file follows nGraph format style.
// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
#pragma once
......@@ -69,7 +69,7 @@ namespace ngraph
using TypeList = llvm::SmallVector<mlir::Type, 4>;
MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel)
: m_compiled_kernel(compiled_kernel)
: m_compiledKernel(compiled_kernel)
{
}
......@@ -77,7 +77,7 @@ namespace ngraph
void compile();
/// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors);
void run(std::vector<void*>& externalTensors);
private:
struct TensorInfo
......@@ -87,66 +87,65 @@ namespace ngraph
};
private:
void build_ng_dialect_module();
void lower_ng_dialect();
void buildNgDialectModule();
void lowerNgDialect();
void optimize();
void bind_arguments(std::vector<void*>& external_tensors);
void bindArguments(std::vector<void*>& externalTensors);
void execute();
void cleanup();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node);
mlir::Type getMlirType(const descriptor::Tensor* tensor);
mlir::Type getMlirType(const element::Type& type);
mlir::Type getMlirType(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
TensorInfo getTensorValue(descriptor::Tensor* tensor);
void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value);
void build_ng_dialect();
void buildNgDialect();
template <typename Op>
static mlir::Operation* create_op(MLIRCompiler& compiler,
const ngraph::Node* ng_node)
static mlir::Operation* createOp(MLIRCompiler& compiler, const ngraph::Node* ngNode)
{
throw std::runtime_error("Unimplemented op '" + ng_node->description() +
throw std::runtime_error("Unimplemented op '" + ngNode->description() +
"' in MLIR Compiler");
}
// Generic op lowerer to ng dialect.
// Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
template <typename Op>
mlir::Operation* create_generic_op(const ngraph::Node* ng_node);
mlir::Operation* createGenericOp(const ngraph::Node* ngNode);
template <typename RedOp>
mlir::Operation* create_index_reduction(const ngraph::Node* ng_node);
mlir::Operation* createIndexReduction(const ngraph::Node* ngNode);
void create_return();
void createReturn();
/// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocate_memref_args();
llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocate_memref_descriptor();
mlir::StaticFloatMemRef* allocateMemrefDescriptor();
/// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
void dump_mlir_module(const std::string msg);
void dumpMlirModule(const std::string msg);
/// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape.
template <typename T>
void get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape);
void getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape);
/// Converts an ngraph shape to an I64 array attribute
template <typename T>
mlir::ArrayAttr get_shape_as_attr(T ng_shape);
mlir::ArrayAttr getShapeAsAttr(T ngShape);
private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
const ngraph::op::CompiledKernel* m_compiledKernel;
// Pointers to externally allocated memory for sub-graph's input and output tensors.
std::vector<void*>* m_external_tensors;
std::vector<void*>* m_externalTensors;
// Arguments for the MLIR function generated for the nGraph sub-graph.
llvm::SmallVector<void*, 8> m_invoke_args;
llvm::SmallVector<void*, 8> m_invokeArgs;
// MLIR context that holds all the MLIR information related to the sub-graph
// compilation.
......@@ -164,11 +163,11 @@ namespace ngraph
// Maps tensor to the value it represents in the IR
// use for MLIR dialect gen
TensorToInfoMap m_tensor_to_value_map;
static const MLIRCompOpMap op_dispatcher;
TensorToInfoMap m_tensorToValueMap;
static const MLIRCompOpMap opDispatcher;
// Optimization level used by MLIR and LLVM compilers.
static unsigned mlir_opt_level;
static unsigned mlirOptLevel;
// LLVM target machine to be used by this MLIR compiler instance to retrieve
// information about target features.
......@@ -178,7 +177,7 @@ namespace ngraph
// machine or configuration flags.
// TODO: Move target machine to external nGraph backend when multiple backends start
// to use MLIR.
static std::unique_ptr<llvm::TargetMachine> target_machine;
static std::unique_ptr<llvm::TargetMachine> targetMachine;
};
}
}
......
......@@ -51,7 +51,9 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
for (auto operand : opr->getOperands())
{
if (i == 0)
{
continue;
}
mlir::Type t = operand->getType();
mlir::NGTensorType opType = t.cast<NGTensorType>();
if (!opType.isCompatible(opType0))
......
......@@ -82,10 +82,14 @@ bool NGTensorType::isCompatible(NGTensorType& other) const
{
// Exact same tensor
if (this == &other)
{
return true;
}
// different tensors, check if of same element type and compatible shapes
if (getElementType() != other.getElementType())
{
return false;
}
// TODO: Handle dynamic ranks
// MLIR MemRefType doesn't seem to support it at the moment.
return isCompatibleShape(other);
......@@ -97,7 +101,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
auto otherShape = other.getShape();
if (shape.size() != otherShape.size())
{
return false;
}
for (auto i = 0; i < shape.size(); i++)
{
......@@ -105,7 +111,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
{
continue;
}
return false;
}
return true;
......
......@@ -104,13 +104,19 @@ namespace
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
{
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
{
return matchFailure();
}
}
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
{
return matchFailure();
}
// Add result types as input args without mapping
result.addInputs(convertedResults);
......@@ -139,16 +145,16 @@ namespace
DialectLoweringPass& pass);
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
template <typename OP>
void lower_unary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass);
ValueHandle createZeroConstant(mlir::Type type);
......@@ -376,49 +382,49 @@ namespace
REWRITER(NGAddOp)
{
lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGSubOp)
{
lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMulOp)
{
lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGDivOp)
{
lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGGreaterOp)
{
lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGLessOp)
{
lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxOp)
{
lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMinOp)
{
lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
return matchSuccess();
}
......@@ -477,7 +483,7 @@ namespace
// Negative
REWRITER(NGNegOp)
{
lower_unary_elementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
lowerUnaryElementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
return matchSuccess();
}
......@@ -950,10 +956,10 @@ namespace
#undef REWRITER
/// End of pattern matchers
template <typename OP>
void lower_unary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
......@@ -999,10 +1005,10 @@ namespace
}
template <typename OP>
void lower_binary_elementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
void lowerBinaryElementwise(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& pass)
{
auto loc = cast<OP>(op).getLoc();
auto result = pass.buildOutputDefs(op, rewriter)[0];
......@@ -1138,7 +1144,9 @@ namespace
for (auto i = 0; i < vArg.rank(); i++)
{
if (i != axis)
{
nonRedIVs.push_back(allIVs[i]);
}
}
// Load current min index with integer data type and convert it to index data type.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment