Commit 7b3c323b authored by Nagy Mostafa, committed by omarkanawi

[MLIR] Fix style in compiler and lowerer files (#3564)

* Fix style in compiler and lowerer files

* Fix comment in headers

* Revert "Fix comment in headers"

This reverts commit d52eed4c1bdf371f3cc7d3f601d9d2b1b0c233e8.

* Fix compiler.* header. Fix code style in other files
parent 76e4485b
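
The diff below renames internal identifiers from nGraph snake_case to MLIR-style camelCase, while public nGraph-facing entry points (e.g. init_mlir, compile, run) keep their snake_case names. A minimal sketch of the resulting convention, assuming nothing beyond what the diff itself shows (the grouping of members here is illustrative):

#include <vector>

class MLIRCompiler
{
public:
    static void init_mlir();                        // public nGraph API: snake_case kept
    void compile();                                 // public API: unchanged
    void run(std::vector<void*>& externalTensors);  // parameter renamed to camelCase

private:
    void buildNgDialectModule();   // was build_ng_dialect_module()
    static unsigned mlirOptLevel;  // was mlir_opt_level
};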
@@ -14,8 +14,8 @@
 // limitations under the License.
 //*****************************************************************************
-// NOTE: This file follows nGraph format style and naming convention since it
-// exposes a public API to the rest of nGraph codebase.
+// NOTE: This file follows nGraph format style.
+// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
 #include "compiler.hpp"
@@ -116,17 +116,17 @@ static llvm::cl::opt<unsigned> clLoopTilingCacheSize(
 "-loop-tile-cache-level."));
 #define COMPILE_OP_DECL(op_name) \
-create_op<op_name>(MLIRCompiler & compiler, const ngraph::Node* ng_node)
+createOp<op_name>(MLIRCompiler & compiler, const ngraph::Node* ngNode)
 // Default optimization level.
-unsigned MLIRCompiler::mlir_opt_level = 2;
+unsigned MLIRCompiler::mlirOptLevel = 2;
 // Target machine will be properly initialized by `init_mlir`.
-std::unique_ptr<llvm::TargetMachine> MLIRCompiler::target_machine;
+std::unique_ptr<llvm::TargetMachine> MLIRCompiler::targetMachine;
 /// Creates target machine for current host.
 static llvm::Expected<std::unique_ptr<llvm::TargetMachine>>
-createDefaultTargetMachine(unsigned opt_level)
+createDefaultTargetMachine(unsigned optLevel)
 {
 auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost();
 if (!machineBuilder)
@@ -142,17 +142,17 @@ static llvm::Expected<std::unique_ptr<llvm::TargetMachine>>
 // Default, // -O2, -Os
 // Aggressive // -O3
 // };
-machineBuilder->setCodeGenOptLevel((llvm::CodeGenOpt::Level)opt_level);
+machineBuilder->setCodeGenOptLevel((llvm::CodeGenOpt::Level)optLevel);
 return machineBuilder->createTargetMachine();
 }
 void MLIRCompiler::init_mlir()
 {
 // Mutex to safely initialize MLIR.
-static std::mutex mlir_init_mutex;
+static std::mutex mlirInitMutex;
 static bool initialized = false;
-std::unique_lock<std::mutex> lock(mlir_init_mutex);
+std::unique_lock<std::mutex> lock(mlirInitMutex);
 if (!initialized)
 {
@@ -164,18 +164,18 @@ void MLIRCompiler::init_mlir()
 llvm::cl::ParseEnvironmentOptions("ngraph", "NGRAPH_MLIR_OPTIONS", "");
 // Override default optimization level with macro value.
-if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
+if (char* optLevelStr = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
 {
-mlir_opt_level = std::stoi(opt_level_str);
-NGRAPH_CHECK(mlir_opt_level >= 0 && mlir_opt_level <= 3, "Invalid optimization level");
+mlirOptLevel = std::stoi(optLevelStr);
+NGRAPH_CHECK(mlirOptLevel >= 0 && mlirOptLevel <= 3, "Invalid optimization level");
 }
 // Initialize LLVM targets and target machine for current host.
 llvm::InitializeNativeTarget();
 llvm::InitializeNativeTargetAsmPrinter();
-auto expected_target_machine = createDefaultTargetMachine(mlir_opt_level);
-NGRAPH_CHECK(expected_target_machine, "Invalid target machine");
-target_machine = std::move(*expected_target_machine);
+auto expectedTargetMachine = createDefaultTargetMachine(mlirOptLevel);
+NGRAPH_CHECK(expectedTargetMachine, "Invalid target machine");
+targetMachine = std::move(*expectedTargetMachine);
 initialized = true;
 }
@@ -183,97 +183,95 @@ void MLIRCompiler::init_mlir()
 void MLIRCompiler::compile()
 {
-build_ng_dialect_module();
-lower_ng_dialect();
+buildNgDialectModule();
+lowerNgDialect();
 }
-void MLIRCompiler::run(std::vector<void*>& external_tensors)
+void MLIRCompiler::run(std::vector<void*>& externalTensors)
 {
-bind_arguments(external_tensors);
+bindArguments(externalTensors);
 execute();
 cleanup();
 }
 // Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
-void MLIRCompiler::build_ng_dialect_module()
+void MLIRCompiler::buildNgDialectModule()
 {
 // initialize an empty module
 m_module = mlir::ModuleOp::create(mlir::UnknownLoc::get(&m_context));
-TypeList args_type_list, result_type_list;
+TypeList argsTypeList, resultTypeList;
 // Retrieve input and output tensors.
-const auto& kernel_inputs = m_compiled_kernel->get_arguments();
-const auto& kernel_outputs = m_compiled_kernel->get_kernel_outputs();
-NGRAPH_CHECK(kernel_inputs.size() != 0, "Cannot have empty inputs list");
-NGRAPH_CHECK(kernel_outputs.size() != 0, "Cannot have empty outputs list");
+const auto& kernelInputs = m_compiledKernel->get_arguments();
+const auto& kernelOutput = m_compiledKernel->get_kernel_outputs();
+NGRAPH_CHECK(kernelInputs.size() != 0, "Cannot have empty inputs list");
+NGRAPH_CHECK(kernelOutput.size() != 0, "Cannot have empty outputs list");
-for (auto input : kernel_inputs)
+for (auto input : kernelInputs)
 {
-args_type_list.push_back(get_mlir_type(input.get()));
+argsTypeList.push_back(getMlirType(input.get()));
 }
-for (auto output : kernel_outputs)
+for (auto output : kernelOutput)
 {
-result_type_list.push_back(get_mlir_type(output.get()));
+resultTypeList.push_back(getMlirType(output.get()));
 }
-auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
-auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(&m_context), "main", func_type);
+auto funcType = mlir::FunctionType::get(argsTypeList, resultTypeList, &m_context);
+auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(&m_context), "main", funcType);
 function.addEntryBlock();
 // populate Tensor->Value maps
 int i = 0;
-for (auto input : kernel_inputs)
+for (auto input : kernelInputs)
 {
 mlir::Value* arg = function.getArgument(i);
-TensorInfo tensor_info{arg};
-m_tensor_to_value_map.insert(
-TensorToInfo(input->get_output_tensor_ptr().get(), tensor_info));
+TensorInfo tensorInfo{arg};
+m_tensorToValueMap.insert(TensorToInfo(input->get_output_tensor_ptr().get(), tensorInfo));
 i++;
 }
 // create builder
 m_builder = std::unique_ptr<mlir::OpBuilder>(new mlir::OpBuilder(function.getBody()));
-build_ng_dialect();
+buildNgDialect();
 m_module->push_back(function);
 if (failed(m_module->verify()))
 {
 NGRAPH_CHECK(false, "Invalid module after lowering to NG dialect");
 }
-dump_mlir_module("nGraph Dialect Construction");
+dumpMlirModule("nGraph Dialect Construction");
 }
 template <typename T>
-void MLIRCompiler::get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
+void MLIRCompiler::getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape)
 {
-for (auto dim : ng_shape)
+for (auto dim : ngShape)
 {
-mlir_shape.push_back(dim);
+mlirShape.push_back(dim);
 }
 }
 template <typename T>
-mlir::ArrayAttr MLIRCompiler::get_shape_as_attr(T ng_shape)
+mlir::ArrayAttr MLIRCompiler::getShapeAsAttr(T ngShape)
 {
-SmallVector<int64_t, 4> mlir_shape;
-get_mlir_shape(ng_shape, mlir_shape);
-return m_builder->getI64ArrayAttr(mlir_shape);
+SmallVector<int64_t, 4> mlirShape;
+getMlirShape(ngShape, mlirShape);
+return m_builder->getI64ArrayAttr(mlirShape);
 }
 // Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
 // element type.
-mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
+mlir::Type MLIRCompiler::getMlirType(const descriptor::Tensor* tensor)
 {
-llvm::SmallVector<int64_t, 4> mlir_shape;
-get_mlir_shape(tensor->get_shape(), mlir_shape);
-return mlir::NGTensorType::get(
-&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
+llvm::SmallVector<int64_t, 4> mlirShape;
+getMlirShape(tensor->get_shape(), mlirShape);
+return mlir::NGTensorType::get(&m_context, getMlirType(tensor->get_element_type()), mlirShape);
 }
 // Converts an nGraph element type into an MLIR type.
-mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
+mlir::Type MLIRCompiler::getMlirType(const element::Type& type)
 {
 #if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
 #pragma GCC diagnostic push
@@ -308,31 +306,31 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
 #endif
 }
-mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
+mlir::Type MLIRCompiler::getMlirType(const ngraph::Node* node)
 {
-descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
-return get_mlir_type(out_tensor);
+descriptor::Tensor* outTensor = node->get_output_tensor_ptr().get();
+return getMlirType(outTensor);
 }
-void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
+void MLIRCompiler::updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value)
 {
-NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
+NGRAPH_CHECK(m_tensorToValueMap.find(tensor) == m_tensorToValueMap.end(),
 "tensor value already defined");
-TensorInfo tensor_info{value};
-m_tensor_to_value_map.insert(TensorToInfo(tensor, tensor_info));
+TensorInfo tensorInfo{value};
+m_tensorToValueMap.insert(TensorToInfo(tensor, tensorInfo));
 }
-MLIRCompiler::TensorInfo MLIRCompiler::get_tensor_value(descriptor::Tensor* tensor)
+MLIRCompiler::TensorInfo MLIRCompiler::getTensorValue(descriptor::Tensor* tensor)
 {
-auto it = m_tensor_to_value_map.find(tensor);
-NGRAPH_CHECK(it != m_tensor_to_value_map.end(), "Undefined tensor");
+auto it = m_tensorToValueMap.find(tensor);
+NGRAPH_CHECK(it != m_tensorToValueMap.end(), "Undefined tensor");
 return it->second;
 }
 // Lowers nGraph dialect all the way to LLVM module.
-void MLIRCompiler::lower_ng_dialect()
+void MLIRCompiler::lowerNgDialect()
 {
 // Lower NG dialect to Affine
 mlir::PassManager pm(&m_context);
@@ -354,27 +352,27 @@ void MLIRCompiler::lower_ng_dialect()
 NGRAPH_CHECK(m_module, "MLIR module is not ready.");
 // Lower Standard dialect to LLVM dialect.
-mlir::LLVMTypeConverter llvm_converter(&m_context);
+mlir::LLVMTypeConverter llvmConverter(&m_context);
 mlir::OwningRewritePatternList patterns;
 mlir::populateLoopToStdConversionPatterns(patterns, &m_context);
-mlir::populateStdToLLVMConversionPatterns(llvm_converter, patterns);
+mlir::populateStdToLLVMConversionPatterns(llvmConverter, patterns);
 mlir::ConversionTarget target(m_context);
 target.addLegalDialect<mlir::LLVM::LLVMDialect>();
 target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
 target.addDynamicallyLegalOp<mlir::FuncOp>(
-[&](mlir::FuncOp op) { return llvm_converter.isSignatureLegal(op.getType()); });
-auto result = applyFullConversion(*m_module, target, std::move(patterns), &llvm_converter);
+[&](mlir::FuncOp op) { return llvmConverter.isSignatureLegal(op.getType()); });
+auto result = applyFullConversion(*m_module, target, std::move(patterns), &llvmConverter);
 NGRAPH_CHECK(succeeded(result), "Standard to LLVM dialect conversion failed");
-dump_mlir_module("LLVM-IR Dialect Conversion");
+dumpMlirModule("LLVM-IR Dialect Conversion");
 // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
 // don't run MLIR passes that were already run. We also pass a default transformer created with
 // the default or user-provided optimization level.
-auto llvm_transformer =
-mlir::makeOptimizingTransformer(mlir_opt_level, /*sizeLevel=*/0, target_machine.get());
-auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
+auto llvmTransformer =
+mlir::makeOptimizingTransformer(mlirOptLevel, /*sizeLevel=*/0, targetMachine.get());
+auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvmTransformer);
 NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
 m_engine = std::move(maybeEngine.get());
 }
@@ -417,13 +415,13 @@ void MLIRCompiler::optimize()
 // LLVM TTI infra while MLIR does not have target model.
 llvm::LLVMContext llvmContext;
 auto module = std::unique_ptr<llvm::Module>(new llvm::Module("test", llvmContext));
-module->setDataLayout(target_machine->createDataLayout());
+module->setDataLayout(targetMachine->createDataLayout());
 auto ttiSetupFunc = llvm::cast<llvm::Function>(
 module
 ->getOrInsertFunction("__ngraph_tti_setup",
 llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), {}))
 .getCallee());
-auto targetInfo = target_machine->getTargetTransformInfo(*ttiSetupFunc);
+auto targetInfo = targetMachine->getTargetTransformInfo(*ttiSetupFunc);
 // Populate pass manager with affine dialect optimizations.
 mlir::PassManager pm(&m_context);
@@ -461,14 +459,14 @@ void MLIRCompiler::optimize()
 // MLIR builders
 #define TI(x) std::type_index(typeid(x))
-void MLIRCompiler::build_ng_dialect()
+void MLIRCompiler::buildNgDialect()
 {
-const NodeVector& sub_graph = m_compiled_kernel->get_node_list();
-for (auto np : sub_graph)
+const NodeVector& subGraph = m_compiledKernel->get_node_list();
+for (auto np : subGraph)
 {
-auto it = op_dispatcher.find(TI(*np));
-if (it == op_dispatcher.end())
+auto it = opDispatcher.find(TI(*np));
+if (it == opDispatcher.end())
 {
 throw unsupported_op{std::string{"The MLIR backend doesn't currently implement the '"} +
 np->description() + "' operation"};
@@ -484,12 +482,12 @@ void MLIRCompiler::build_ng_dialect()
 mlir::Value* result = op->getResult(i);
 if (result)
 {
-update_tensor_value(np->get_output_tensor_ptr(i).get(), result);
+updateTensorValue(np->get_output_tensor_ptr(i).get(), result);
 }
 }
 }
 }
-create_return();
+createReturn();
 }
 namespace ngraph
@@ -501,118 +499,117 @@ namespace ngraph
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
 {
-return compiler.create_generic_op<mlir::NGAddOp>(ng_node);
+return compiler.createGenericOp<mlir::NGAddOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Subtract)
 {
-return compiler.create_generic_op<mlir::NGSubOp>(ng_node);
+return compiler.createGenericOp<mlir::NGSubOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Multiply)
 {
-return compiler.create_generic_op<mlir::NGMulOp>(ng_node);
+return compiler.createGenericOp<mlir::NGMulOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Divide)
 {
-return compiler.create_generic_op<mlir::NGDivOp>(ng_node);
+return compiler.createGenericOp<mlir::NGDivOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Greater)
 {
-return compiler.create_generic_op<mlir::NGGreaterOp>(ng_node);
+return compiler.createGenericOp<mlir::NGGreaterOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Less)
 {
-return compiler.create_generic_op<mlir::NGLessOp>(ng_node);
+return compiler.createGenericOp<mlir::NGLessOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Maximum)
 {
-return compiler.create_generic_op<mlir::NGMaxOp>(ng_node);
+return compiler.createGenericOp<mlir::NGMaxOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Minimum)
 {
-return compiler.create_generic_op<mlir::NGMinOp>(ng_node);
+return compiler.createGenericOp<mlir::NGMinOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
 {
-return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
+return compiler.createIndexReduction<mlir::NGArgMaxRedOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
 {
-return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
+return compiler.createIndexReduction<mlir::NGArgMinRedOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
 {
-return compiler.create_generic_op<mlir::NGDotOp>(ng_node);
+return compiler.createGenericOp<mlir::NGDotOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Concat)
 {
-auto ng_node_concat = static_cast<const ngraph::op::Concat*>(ng_node);
-auto op = compiler.create_generic_op<mlir::NGConcatOp>(ng_node);
-op->setAttr("concatenation_axis",
-compiler.m_builder->getI64IntegerAttr(
-ng_node_concat->get_concatenation_axis()));
+auto concat = static_cast<const ngraph::op::Concat*>(ngNode);
+auto op = compiler.createGenericOp<mlir::NGConcatOp>(ngNode);
+op->setAttr(
+"concatenation_axis",
+compiler.m_builder->getI64IntegerAttr(concat->get_concatenation_axis()));
 return op;
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Gather)
 {
-auto ng_node_gather = static_cast<const ngraph::op::Gather*>(ng_node);
-auto op = compiler.create_generic_op<mlir::NGGatherOp>(ng_node);
-op->setAttr("axis",
-compiler.m_builder->getI64IntegerAttr(ng_node_gather->get_axis()));
+auto gather = static_cast<const ngraph::op::Gather*>(ngNode);
+auto op = compiler.createGenericOp<mlir::NGGatherOp>(ngNode);
+op->setAttr("axis", compiler.m_builder->getI64IntegerAttr(gather->get_axis()));
 return op;
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Relu)
 {
-return compiler.create_generic_op<mlir::NGReluOp>(ng_node);
+return compiler.createGenericOp<mlir::NGReluOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Negative)
 {
-return compiler.create_generic_op<mlir::NGNegOp>(ng_node);
+return compiler.createGenericOp<mlir::NGNegOp>(ngNode);
 }
 template <>
 mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Convolution)
 {
-mlir::Operation* op = compiler.create_generic_op<mlir::NGConvolutionOp>(ng_node);
-auto conv_node = static_cast<const ngraph::op::Convolution*>(ng_node);
-auto conv_op = llvm::cast<mlir::NGConvolutionOp>(op);
+mlir::Operation* op = compiler.createGenericOp<mlir::NGConvolutionOp>(ngNode);
+auto convNode = static_cast<const ngraph::op::Convolution*>(ngNode);
+auto convOp = llvm::cast<mlir::NGConvolutionOp>(op);
 mlir::ArrayAttr attr =
-compiler.get_shape_as_attr(conv_node->get_window_movement_strides());
-conv_op.setStrides(attr);
-attr = compiler.get_shape_as_attr(conv_node->get_padding_below());
-conv_op.setPadBelow(attr);
-attr = compiler.get_shape_as_attr(conv_node->get_padding_above());
-conv_op.setPadAbove(attr);
+compiler.getShapeAsAttr(convNode->get_window_movement_strides());
+convOp.setStrides(attr);
+attr = compiler.getShapeAsAttr(convNode->get_padding_below());
+convOp.setPadBelow(attr);
+attr = compiler.getShapeAsAttr(convNode->get_padding_above());
+convOp.setPadAbove(attr);
 return op;
 }
 }
@@ -620,58 +617,58 @@ namespace ngraph
 }
 template <typename Op>
-mlir::Operation* MLIRCompiler::create_generic_op(const ngraph::Node* ng_node)
+mlir::Operation* MLIRCompiler::createGenericOp(const ngraph::Node* ngNode)
 {
-std::vector<mlir::Value*> arg_values;
-std::vector<mlir::Type> res_types;
-for (auto& arg : ng_node->get_arguments())
+std::vector<mlir::Value*> argValues;
+std::vector<mlir::Type> resTypes;
+for (auto& arg : ngNode->get_arguments())
 {
-auto arg_tensor = arg->get_output_tensor_ptr();
-auto arg_v = get_tensor_value(arg_tensor.get()).m_value;
-arg_values.push_back(arg_v);
+auto argTensor = arg->get_output_tensor_ptr();
+auto argv = getTensorValue(argTensor.get()).m_value;
+argValues.push_back(argv);
 }
-for (auto& output : ng_node->outputs())
+for (auto& output : ngNode->outputs())
 {
-res_types.push_back(get_mlir_type(output.get_tensor_ptr().get()));
+resTypes.push_back(getMlirType(output.get_tensor_ptr().get()));
 }
 return (m_builder->create<Op,
 ArrayRef<mlir::Type>,
 ArrayRef<mlir::Value*>,
 ArrayRef<mlir::NamedAttribute>>(
-mlir::UnknownLoc::get(&m_context), res_types, arg_values, {/* no attrs */}))
+mlir::UnknownLoc::get(&m_context), resTypes, argValues, {/* no attrs */}))
 .getOperation();
 }
-const MLIRCompiler::MLIRCompOpMap MLIRCompiler::op_dispatcher{
-#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::create_op<ngraph::op::OP>},
+const MLIRCompiler::MLIRCompOpMap MLIRCompiler::opDispatcher{
+#define MLIR_OP(OP) {TI(ngraph::op::OP), &MLIRCompiler::createOp<ngraph::op::OP>},
 #include "ops_supported.inc"
 };
-void MLIRCompiler::create_return()
+void MLIRCompiler::createReturn()
 {
-std::vector<mlir::Value*> value_list;
-for (auto output : m_compiled_kernel->get_kernel_outputs())
+std::vector<mlir::Value*> valueList;
+for (auto output : m_compiledKernel->get_kernel_outputs())
 {
-value_list.push_back(get_tensor_value(output->get_output_tensor_ptr().get()).m_value);
+valueList.push_back(getTensorValue(output->get_output_tensor_ptr().get()).m_value);
 }
-m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
+m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), valueList);
 }
 template <typename RedOp>
-mlir::Operation* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
+mlir::Operation* MLIRCompiler::createIndexReduction(const ngraph::Node* ngNode)
 {
-auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
-auto op = create_generic_op<RedOp>(ng_node);
-mlir::ArrayAttr red_axes_attr =
-m_builder->getI64ArrayAttr({(int64_t)idx_red->get_reduction_axis()});
-op->setAttr("axes", red_axes_attr);
+auto* idxRed = static_cast<const ngraph::op::util::IndexReduction*>(ngNode);
+auto op = createGenericOp<RedOp>(ngNode);
+mlir::ArrayAttr redAxesAttr =
+m_builder->getI64ArrayAttr({(int64_t)idxRed->get_reduction_axis()});
+op->setAttr("axes", redAxesAttr);
 return op;
 }
 // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
 // helpers to be used inside the function.
-void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
+void MLIRCompiler::bindArguments(std::vector<void*>& externalTensors)
 {
 NGRAPH_CHECK(m_module, "MLIR module is not ready.");
@@ -679,29 +676,29 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
 NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
 // Set external arguments
-NGRAPH_CHECK(m_compiled_kernel, "No compiled kernel set for compiler");
-NGRAPH_CHECK((m_compiled_kernel->get_arguments().size() +
-m_compiled_kernel->get_kernel_outputs().size()) == external_tensors.size(),
+NGRAPH_CHECK(m_compiledKernel, "No compiled kernel set for compiler");
+NGRAPH_CHECK((m_compiledKernel->get_arguments().size() +
+m_compiledKernel->get_kernel_outputs().size()) == externalTensors.size(),
 "Number of arguments and outputs doesn't match number of tensors");
-m_external_tensors = &external_tensors;
+m_externalTensors = &externalTensors;
 // Create list with a type-erased double pointer for each invocation arguments.
-// We currently use 'allocateMemRefArguments', which creates a
+// We currently use 'allocateMemrefArgs', which creates a
 // SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
 // actual pointer to the data.
 // create MemRef args
-auto expected_arguments = allocate_memref_args();
-NGRAPH_CHECK(expected_arguments.size(), "Arguments can't be created");
-m_invoke_args = std::move(expected_arguments);
-NGRAPH_CHECK(m_invoke_args.size() == m_external_tensors->size(),
+auto expectedArguments = allocateMemrefArgs();
+NGRAPH_CHECK(expectedArguments.size(), "Arguments can't be created");
+m_invokeArgs = std::move(expectedArguments);
+NGRAPH_CHECK(m_invokeArgs.size() == m_externalTensors->size(),
 "Number of external tensors doesn't match number of function arguments");
 // Assign external tensor pointers to invocation arguments.
-for (size_t i = 0, num_args = m_invoke_args.size(); i < num_args; ++i)
+for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
 {
-((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i];
+((mlir::StaticFloatMemRef*)m_invokeArgs[i])->data = (float*)(*m_externalTensors)[i];
 }
 }
@@ -712,14 +709,14 @@ void MLIRCompiler::execute()
 // uniformity reasons, it takes a list of type-erased pointers to arguments.
 // Please, note that 'invoke' method is overloaded with a parameter pack version.
 // Make sure the MutableArrayRef version is invoked.
-auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invoke_args));
+auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
 NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
 }
 void MLIRCompiler::cleanup()
 {
 // Free void double pointer arguments without freeing external tensor data.
-for (auto* arg : m_invoke_args)
+for (auto* arg : m_invokeArgs)
 {
 free(arg);
 }
@@ -731,18 +728,18 @@ void MLIRCompiler::cleanup()
 }
 }
-SmallVector<void*, 8> MLIRCompiler::allocate_memref_args()
+SmallVector<void*, 8> MLIRCompiler::allocateMemrefArgs()
 {
 SmallVector<void*, 8> args;
-for (auto i = 0; i < m_external_tensors->size(); i++)
+for (auto i = 0; i < m_externalTensors->size(); i++)
 {
-auto descriptor = allocate_memref_descriptor();
+auto descriptor = allocateMemrefDescriptor();
 args.push_back(descriptor);
 }
 return args;
 }
-mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor()
+mlir::StaticFloatMemRef* MLIRCompiler::allocateMemrefDescriptor()
 {
 // We only use StaticFloatMemRef because that's what MLIR currently offers.
 // We should expand this with different types and dynamic MemRefs
@@ -753,7 +750,7 @@ mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor()
 return descriptor;
 }
-void MLIRCompiler::dump_mlir_module(const std::string msg)
+void MLIRCompiler::dumpMlirModule(const std::string msg)
 {
 if (clPrintIRAfterAll)
 {
...
@@ -14,8 +14,8 @@
 // limitations under the License.
 //*****************************************************************************
-// NOTE: This file follows nGraph format style and naming convention since it
-// exposes a public API to the rest of nGraph codebase.
+// NOTE: This file follows nGraph format style.
+// Follows nGraph naming convention for public APIs only, else MLIR naming convention.
 #pragma once
@@ -69,7 +69,7 @@ namespace ngraph
 using TypeList = llvm::SmallVector<mlir::Type, 4>;
 MLIRCompiler(const ngraph::op::CompiledKernel* compiled_kernel)
-: m_compiled_kernel(compiled_kernel)
+: m_compiledKernel(compiled_kernel)
 {
 }
@@ -77,7 +77,7 @@ namespace ngraph
 void compile();
 /// Executes a pre-compiled subgraph
-void run(std::vector<void*>& external_tensors);
+void run(std::vector<void*>& externalTensors);
 private:
 struct TensorInfo
@@ -87,66 +87,65 @@ namespace ngraph
 };
 private:
-void build_ng_dialect_module();
-void lower_ng_dialect();
+void buildNgDialectModule();
+void lowerNgDialect();
 void optimize();
-void bind_arguments(std::vector<void*>& external_tensors);
+void bindArguments(std::vector<void*>& externalTensors);
 void execute();
 void cleanup();
-mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
-mlir::Type get_mlir_type(const element::Type& type);
-mlir::Type get_mlir_type(const ngraph::Node* node);
+mlir::Type getMlirType(const descriptor::Tensor* tensor);
+mlir::Type getMlirType(const element::Type& type);
+mlir::Type getMlirType(const ngraph::Node* node);
-TensorInfo get_tensor_value(descriptor::Tensor* tensor);
-void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
+TensorInfo getTensorValue(descriptor::Tensor* tensor);
+void updateTensorValue(descriptor::Tensor* tensor, mlir::Value* value);
-void build_ng_dialect();
+void buildNgDialect();
 template <typename Op>
-static mlir::Operation* create_op(MLIRCompiler& compiler,
-const ngraph::Node* ng_node)
+static mlir::Operation* createOp(MLIRCompiler& compiler, const ngraph::Node* ngNode)
 {
-throw std::runtime_error("Unimplemented op '" + ng_node->description() +
+throw std::runtime_error("Unimplemented op '" + ngNode->description() +
 "' in MLIR Compiler");
 }
 // Generic op lowerer to ng dialect.
 // Simply maps ngraph tensors to values and generate an OP. No op-specific logic.
 template <typename Op>
-mlir::Operation* create_generic_op(const ngraph::Node* ng_node);
+mlir::Operation* createGenericOp(const ngraph::Node* ngNode);
 template <typename RedOp>
-mlir::Operation* create_index_reduction(const ngraph::Node* ng_node);
+mlir::Operation* createIndexReduction(const ngraph::Node* ngNode);
-void create_return();
+void createReturn();
 /// Helper to create memref arguments for MLIR function signature
-llvm::SmallVector<void*, 8> allocate_memref_args();
+llvm::SmallVector<void*, 8> allocateMemrefArgs();
 /// Helper to allocate a mem ref object. Handles static shapes only for now.
-mlir::StaticFloatMemRef* allocate_memref_descriptor();
+mlir::StaticFloatMemRef* allocateMemrefDescriptor();
 /// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
-void dump_mlir_module(const std::string msg);
+void dumpMlirModule(const std::string msg);
 /// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape.
 template <typename T>
-void get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape);
+void getMlirShape(T ngShape, llvm::SmallVectorImpl<int64_t>& mlirShape);
 /// Converts an ngraph shape to an I64 array attribute
 template <typename T>
-mlir::ArrayAttr get_shape_as_attr(T ng_shape);
+mlir::ArrayAttr getShapeAsAttr(T ngShape);
 private:
 // Sub-graph to be compiled and executed with MLIR.
-const ngraph::op::CompiledKernel* m_compiled_kernel;
+const ngraph::op::CompiledKernel* m_compiledKernel;
 // Pointers to externally allocated memory for sub-graph's input and output tensors.
-std::vector<void*>* m_external_tensors;
+std::vector<void*>* m_externalTensors;
 // Arguments for the MLIR function generated for the nGraph sub-graph.
-llvm::SmallVector<void*, 8> m_invoke_args;
+llvm::SmallVector<void*, 8> m_invokeArgs;
 // MLIR context that holds all the MLIR information related to the sub-graph
 // compilation.
@@ -164,11 +163,11 @@ namespace ngraph
 // Maps tensor to the value it represents in the IR
 // use for MLIR dialect gen
-TensorToInfoMap m_tensor_to_value_map;
+TensorToInfoMap m_tensorToValueMap;
-static const MLIRCompOpMap op_dispatcher;
+static const MLIRCompOpMap opDispatcher;
 // Optimization level used by MLIR and LLVM compilers.
-static unsigned mlir_opt_level;
+static unsigned mlirOptLevel;
 // LLVM target machine to be used by this MLIR compiler instance to retrieve
 // information about target features.
@@ -178,7 +177,7 @@ namespace ngraph
 // machine or configuration flags.
 // TODO: Move target machine to external nGraph backend when multiple backends start
 // to use MLIR.
-static std::unique_ptr<llvm::TargetMachine> target_machine;
+static std::unique_ptr<llvm::TargetMachine> targetMachine;
 };
 }
 }
...
@@ -51,7 +51,9 @@ static mlir::LogicalResult verifyCompatibleOperandsAndResults(T* op, bool checkR
 for (auto operand : opr->getOperands())
 {
 if (i == 0)
+{
 continue;
+}
 mlir::Type t = operand->getType();
 mlir::NGTensorType opType = t.cast<NGTensorType>();
 if (!opType.isCompatible(opType0))
...
@@ -82,10 +82,14 @@ bool NGTensorType::isCompatible(NGTensorType& other) const
 {
 // Exact same tensor
 if (this == &other)
+{
 return true;
+}
 // different tensors, check if of same element type and compatible shapes
 if (getElementType() != other.getElementType())
+{
 return false;
+}
 // TODO: Handle dynamic ranks
 // MLIR MemRefType doesn't seem to support it at the moment.
 return isCompatibleShape(other);
@@ -97,7 +101,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
 auto otherShape = other.getShape();
 if (shape.size() != otherShape.size())
+{
 return false;
+}
 for (auto i = 0; i < shape.size(); i++)
 {
@@ -105,7 +111,9 @@ bool NGTensorType::isCompatibleShape(NGTensorType& other) const
 NGRAPH_CHECK(otherShape[i] >= -1, "Invalid tensor shape", otherShape[i]);
 if (shape[i] == -1 || otherShape[i] == -1 || shape[i] == otherShape[i])
+{
 continue;
+}
 return false;
 }
 return true;
...
@@ -104,13 +104,19 @@ namespace
 // Convert the original function arguments.
 TypeConverter::SignatureConversion result(type.getNumInputs());
 for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+{
 if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
+{
 return matchFailure();
+}
+}
 // Convert the original function results.
 SmallVector<Type, 4> convertedResults;
 if (failed(converter.convertTypes(type.getResults(), convertedResults)))
+{
 return matchFailure();
+}
 // Add result types as input args without mapping
 result.addInputs(convertedResults);
@@ -139,16 +145,16 @@ namespace
 DialectLoweringPass& pass);
 template <typename OP>
-void lower_binary_elementwise(Operation* op,
+void lowerBinaryElementwise(Operation* op,
 ArrayRef<Value*> operands,
 PatternRewriter& rewriter,
 DialectLoweringPass& pass);
 template <typename OP>
-void lower_unary_elementwise(Operation* op,
+void lowerUnaryElementwise(Operation* op,
 ArrayRef<Value*> operands,
 PatternRewriter& rewriter,
 DialectLoweringPass& pass);
 ValueHandle createZeroConstant(mlir::Type type);
@@ -376,49 +382,49 @@ namespace
 REWRITER(NGAddOp)
 {
-lower_binary_elementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGAddOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGSubOp)
 {
-lower_binary_elementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGSubOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGMulOp)
 {
-lower_binary_elementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGMulOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGDivOp)
 {
-lower_binary_elementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGDivOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGGreaterOp)
 {
-lower_binary_elementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGGreaterOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGLessOp)
 {
-lower_binary_elementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGLessOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGMaxOp)
 {
-lower_binary_elementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGMaxOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
 REWRITER(NGMinOp)
 {
-lower_binary_elementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
+lowerBinaryElementwise<mlir::NGMinOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
@@ -477,7 +483,7 @@ namespace
 // Negative
 REWRITER(NGNegOp)
 {
-lower_unary_elementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
+lowerUnaryElementwise<mlir::NGNegOp>(op, operands, rewriter, pass);
 return matchSuccess();
 }
@@ -950,10 +956,10 @@ namespace
 #undef REWRITER
 /// End of pattern matchers
 template <typename OP>
-void lower_unary_elementwise(Operation* op,
+void lowerUnaryElementwise(Operation* op,
 ArrayRef<Value*> operands,
 PatternRewriter& rewriter,
 DialectLoweringPass& pass)
 {
 auto loc = cast<OP>(op).getLoc();
@@ -999,10 +1005,10 @@ namespace
 }
 template <typename OP>
-void lower_binary_elementwise(Operation* op,
+void lowerBinaryElementwise(Operation* op,
 ArrayRef<Value*> operands,
 PatternRewriter& rewriter,
 DialectLoweringPass& pass)
 {
 auto loc = cast<OP>(op).getLoc();
 auto result = pass.buildOutputDefs(op, rewriter)[0];
@@ -1138,7 +1144,9 @@ namespace
 for (auto i = 0; i < vArg.rank(); i++)
 {
 if (i != axis)
+{
 nonRedIVs.push_back(allIVs[i]);
+}
 }
 // Load current min index with integer data type and convert it to index data type.
...
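
For reference, a hypothetical caller of the renamed public API would follow the order these sources establish: one-time MLIR initialization, then compile, then run with the external tensor buffers. This is a sketch only; the kernel pointer and buffer setup are placeholders, and namespace qualifiers on MLIRCompiler are omitted.

// Hypothetical usage sketch of the public API above.
void runSubgraph(const ngraph::op::CompiledKernel* kernel, std::vector<void*>& tensors)
{
    MLIRCompiler::init_mlir();      // one-time global init (mutex-guarded, idempotent)
    MLIRCompiler compiler(kernel);  // wrap the sub-graph to be JIT-compiled
    compiler.compile();             // buildNgDialectModule() + lowerNgDialect()
    compiler.run(tensors);          // bindArguments(), execute(), cleanup()
}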