Unverified Commit e1d9ee66 authored by Diego Caballero's avatar Diego Caballero Committed by GitHub

[MLIR] Enable bare ptr calling convention (#4336)

* [MLIR] Update MLIR repo

* Nagy's fix

* Changes related to mlir-opt

* Update MLIR commit

* [MLIR] Enable bare ptr calling convention

This PR adds a flag to enable the bare pointer calling convention
in LLVM lowering pass.

* Update MLIR commit and callbacks.

* Disable 'noalias' attribute.

It will be re-introduced in a follow-up commit.

* Remove '__mlir' prefix in callback test

* Address feedback

* Fix EDSC includes

* Move MLIR repo forward

* Update type converter code

* Address feedback

* Enable 'noalias' code

* Address feedback
Co-authored-by: 's avatarAmy Zhuang <amyzhuang97@gmail.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent d9242244
...@@ -71,7 +71,11 @@ static llvm::cl::opt<unsigned> clLoopTilingCacheSize( ...@@ -71,7 +71,11 @@ static llvm::cl::opt<unsigned> clLoopTilingCacheSize(
"inferred from the host CPU using for the cache level specified by " "inferred from the host CPU using for the cache level specified by "
"-ngraph-loop-tile-cache-level.")); "-ngraph-loop-tile-cache-level."));
// Enable the lowering of MemRefs to LLVM bare pointers.
extern llvm::cl::opt<bool> clEnableBarePtrMemRefLowering;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
using namespace mlir;
// Default optimization level. // Default optimization level.
llvm::CodeGenOpt::Level MLIRCPUBackend::mlirOptLevel = llvm::CodeGenOpt::Level::Aggressive; llvm::CodeGenOpt::Level MLIRCPUBackend::mlirOptLevel = llvm::CodeGenOpt::Level::Aggressive;
...@@ -194,8 +198,19 @@ void MLIRCPUBackend::lowerNgDialect() ...@@ -194,8 +198,19 @@ void MLIRCPUBackend::lowerNgDialect()
void MLIRCPUBackend::lowerStandardDialect() void MLIRCPUBackend::lowerStandardDialect()
{ {
mlir::PassManager pm(&m_context); mlir::PassManager pm(&m_context);
pm.addPass(mlir::createLowerToLLVMPass( // We lower memrefs to a fat memref descriptor by default. If 'clEnableBarePtrMemRefLowering' is
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true)); // specified, we lower memref arguments to bare pointers to the memref element type.
if (clEnableBarePtrMemRefLowering)
{
pm.addPass(mlir::createLowerToLLVMPass(/*useAlloca=*/false,
/*useBarePtrCallConv=*/true,
/*emitCWrappers=*/false));
}
else
{
pm.addPass(mlir::createLowerToLLVMPass(
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));
}
// Apply any generic pass manager command line options. // Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm); mlir::applyPassManagerCLOptions(pm);
......
...@@ -43,6 +43,9 @@ ...@@ -43,6 +43,9 @@
#define PASS_NAME "convert-ngraph-to-affine" #define PASS_NAME "convert-ngraph-to-affine"
#define DEBUG_TYPE PASS_NAME #define DEBUG_TYPE PASS_NAME
// Enable the lowering of MemRefs to LLVM bare pointers.
extern llvm::cl::opt<bool> clEnableBarePtrMemRefLowering;
std::vector<ngraph::runtime::ngmlir::opAttrs> opAttrsVec; std::vector<ngraph::runtime::ngmlir::opAttrs> opAttrsVec;
// anonymous namespace // anonymous namespace
...@@ -344,8 +347,10 @@ namespace ...@@ -344,8 +347,10 @@ namespace
// TODO: Encode no alias attribute as part of the function signature conversion or as a // TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion. // separate rewrite pattern. Retrieve new function after signature conversion.
// TODO: To be enabled in follow-up commit. if (clEnableBarePtrMemRefLowering)
// insertNoAliasArgAttrs(); {
insertNoAliasArgAttrs();
}
} }
opAttrsVec = m_attrsVec; opAttrsVec = m_attrsVec;
...@@ -520,22 +525,22 @@ namespace ...@@ -520,22 +525,22 @@ namespace
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe /// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics. /// by nGraph op semantics.
// void DialectLoweringPass::insertNoAliasArgAttrs() void DialectLoweringPass::insertNoAliasArgAttrs()
//{ {
// FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName); FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
// NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found"); NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
// unsigned int argIdx = 0; unsigned int argIdx = 0;
// for (auto arg : func.getArguments()) for (auto arg : func.getArguments())
// { {
// if (arg.getType().isa<MemRefType>()) if (arg.getType().isa<MemRefType>())
// { {
// func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext())); func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
// } }
// ++argIdx; ++argIdx;
// } }
//} }
void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter) void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
{ {
......
...@@ -53,6 +53,13 @@ static llvm::cl::opt<std::string> ...@@ -53,6 +53,13 @@ static llvm::cl::opt<std::string>
clObjectFilename("ngraph-mlir-object-filename", clObjectFilename("ngraph-mlir-object-filename",
llvm::cl::desc("Dump MLIR JITted-compiled object to file jitted_mlir.o")); llvm::cl::desc("Dump MLIR JITted-compiled object to file jitted_mlir.o"));
// The bare pointer calling convention lowers memref arguments to bare pointers to the memref
// element type.
llvm::cl::opt<bool> clEnableBarePtrMemRefLowering(
"ngraph-bare-ptr-memref-lowering",
llvm::cl::init(false),
llvm::cl::desc("Enable the lowering of MemRefs to LLVM bare pointers"));
void MLIRCPURuntime::run(const std::vector<MemRefArg>& args) void MLIRCPURuntime::run(const std::vector<MemRefArg>& args)
{ {
// run_internal(*reinterpret_cast<std::vector<void*>*>(args), shapeVec, stridesVec); // run_internal(*reinterpret_cast<std::vector<void*>*>(args), shapeVec, stridesVec);
...@@ -108,14 +115,24 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args) ...@@ -108,14 +115,24 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
// Assign external tensor pointers to invocation arguments. // Assign external tensor pointers to invocation arguments.
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i) for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{ {
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(m_invokeArgs[i])); if (!clEnableBarePtrMemRefLowering)
memRefArg->allocatedPtr = (*m_externalTensors)[i].m_tensor; {
memRefArg->alignedPtr = (*m_externalTensors)[i].m_tensor; // Default memref lowering lowers memrefs to StaticMemRef descriptors.
auto rank = m_ranks[i]; auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(m_invokeArgs[i]));
for (auto j = 0; j < rank; j++) memRefArg->allocatedPtr = (*m_externalTensors)[i].m_tensor;
memRefArg->alignedPtr = (*m_externalTensors)[i].m_tensor;
auto rank = m_ranks[i];
for (auto j = 0; j < rank; j++)
{
memRefArg->shapeAndStrides[j] = (*m_externalTensors)[i].m_shape[j];
memRefArg->shapeAndStrides[rank + j] = (*m_externalTensors)[i].m_strides[j];
}
}
else
{ {
memRefArg->shapeAndStrides[j] = (*m_externalTensors)[i].m_shape[j]; // Custom memref lowering lowers memref arguments to bare pointers to tensors.
memRefArg->shapeAndStrides[rank + j] = (*m_externalTensors)[i].m_strides[j]; auto** memRefArg = reinterpret_cast<void**>(m_invokeArgs[i]);
*memRefArg = (*m_externalTensors)[i].m_tensor;
} }
} }
} }
...@@ -143,33 +160,62 @@ void MLIRCPURuntime::cleanup() ...@@ -143,33 +160,62 @@ void MLIRCPURuntime::cleanup()
// Free void double pointer arguments without freeing external tensor data. // Free void double pointer arguments without freeing external tensor data.
for (auto* arg : m_invokeArgs) for (auto* arg : m_invokeArgs)
{ {
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg)); if (!clEnableBarePtrMemRefLowering)
free(memRefArg); {
free(arg); // Default memref lowering lowers memrefs to StaticMemRef descriptors.
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
free(memRefArg);
free(arg);
}
else
{
// Custom memref lowering lowers memref arguments to bare pointers to tensors.
auto** memRefArg = reinterpret_cast<void**>(arg);
free(memRefArg);
}
} }
} }
// The current call ABI takes a single arg pointer (argPtr) pointing to a list of args. // The default call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// Each arg is a pointer to a StaticMemRef which contains a data pointer // Each arg is a pointer to a StaticMemRef which contains a data pointer
// //
// The args are laid out as follows // The args are laid out as follows
// argPtr-> arg[0]-> StaticMemRef -> <data> // argPtr-> arg[0]-> StaticMemRef -> <data>
// arg[1]-> StaticMemRef -> <data> // arg[1]-> StaticMemRef -> <data>
// ... // ...
//
// The bare pointer ABI takes a single arg pointer pointing to data for that MemRef. Not extra
// information about the MemRef is passed at the moment. Example:
//
// Args are laid out as follows:
// arg0Ptr-> <data>
// arg1Ptr-> <data>
// ...
SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs() SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
{ {
SmallVector<void*, 8> args; SmallVector<void*, 8> args;
for (auto i = 0; i < m_externalTensors->size(); i++) for (auto i = 0; i < m_externalTensors->size(); i++)
{ {
auto descriptor = allocateMemrefDescriptor(m_ranks[i]); if (!clEnableBarePtrMemRefLowering)
StaticMemRef** arg = reinterpret_cast<StaticMemRef**>(malloc(sizeof(StaticMemRef*))); {
*arg = descriptor; // Default memref lowering lowers memrefs to StaticMemRef descriptors.
args.push_back(arg); auto descriptor = allocateDefaultMemrefDescriptor(m_ranks[i]);
StaticMemRef** arg = reinterpret_cast<StaticMemRef**>(malloc(sizeof(StaticMemRef*)));
*arg = descriptor;
args.push_back(arg);
}
else
{
// Custom memref lowering lowers memref arguments to bare pointers to tensors.
auto** arg = reinterpret_cast<void**>(malloc(sizeof(void**)));
*arg = reinterpret_cast<void*>(malloc(sizeof(void*)));
args.push_back(arg);
}
} }
return args; return args;
} }
StaticMemRef* MLIRCPURuntime::allocateMemrefDescriptor(size_t rank) StaticMemRef* MLIRCPURuntime::allocateDefaultMemrefDescriptor(size_t rank)
{ {
// We only use StaticMemRef because that's what MLIR currently offers. // We only use StaticMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs // We should expand this with different types and dynamic MemRefs
......
...@@ -69,8 +69,9 @@ namespace ngraph ...@@ -69,8 +69,9 @@ namespace ngraph
/// Helper to create memref arguments for MLIR function signature /// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocateMemrefArgs(); llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now. /// Helper to allocate a default MemRef descriptor for LLVM. Handles static shapes
StaticMemRef* allocateMemrefDescriptor(size_t); /// only for now.
StaticMemRef* allocateDefaultMemrefDescriptor(size_t);
private: private:
// Pointers to externally allocated memory for sub-graph's input and output tensors. // Pointers to externally allocated memory for sub-graph's input and output tensors.
......
// RUN: ngraph-opt %s -convert-ngraph-to-affine -ngraph-bare-ptr-memref-lowering -split-input-file | FileCheck %s --check-prefix=BARE-PTR-CC
// RUN: ngraph-opt %s -convert-ngraph-to-affine -split-input-file | FileCheck %s --check-prefix=STD-CC
// Tests related to the bare pointer calling convention.
// Verify that the `noalias` attribute is generated when the bare pointer calling
// convention is used but not with the standard calling convention.
func @noalias_attribute(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>){
"ng.return"() : () -> ()
}
// BARE-PTR-CC-LABEL: func @noalias_attribute
// BARE-PTR-CC-SAME: %{{.*}}: memref<16xi64> {llvm.noalias = true}
// BARE-PTR-CC-SAME: %{{.*}}: memref<512x32xf32> {llvm.noalias = true})
// STD-CC-LABEL: func @noalias_attribute
// STD-CC-NOT: llvm.noalias
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment