Unverified Commit e1d9ee66 authored by Diego Caballero's avatar Diego Caballero Committed by GitHub

[MLIR] Enable bare ptr calling convention (#4336)

* [MLIR] Update MLIR repo

* Nagy's fix

* Changes related to mlir-opt

* Update MLIR commit

* [MLIR] Enable bare ptr calling convention

This PR adds a flag to enable the bare pointer calling convention
in LLVM lowering pass.

* Update MLIR commit and callbacks.

* Disable 'noalias' attribute.

It will be re-introduced in a follow-up commit.

* Remove '__mlir' prefix in callback test

* Address feedback

* Fix EDSC includes

* Move MLIR repo forward

* Update type converter code

* Address feedback

* Enable 'noalias' code

* Address feedback
Co-authored-by: 's avatarAmy Zhuang <amyzhuang97@gmail.com>
Co-authored-by: 's avatarSang Ik Lee <sang.ik.lee@intel.com>
parent d9242244
......@@ -71,7 +71,11 @@ static llvm::cl::opt<unsigned> clLoopTilingCacheSize(
"inferred from the host CPU using for the cache level specified by "
"-ngraph-loop-tile-cache-level."));
// Enable the lowering of MemRefs to LLVM bare pointers.
extern llvm::cl::opt<bool> clEnableBarePtrMemRefLowering;
using namespace ngraph::runtime::ngmlir;
using namespace mlir;
// Default optimization level.
llvm::CodeGenOpt::Level MLIRCPUBackend::mlirOptLevel = llvm::CodeGenOpt::Level::Aggressive;
......@@ -194,8 +198,19 @@ void MLIRCPUBackend::lowerNgDialect()
void MLIRCPUBackend::lowerStandardDialect()
{
mlir::PassManager pm(&m_context);
// We lower memrefs to a fat memref descriptor by default. If 'clEnableBarePtrMemRefLowering' is
// specified, we lower memref arguments to bare pointers to the memref element type.
if (clEnableBarePtrMemRefLowering)
{
pm.addPass(mlir::createLowerToLLVMPass(/*useAlloca=*/false,
/*useBarePtrCallConv=*/true,
/*emitCWrappers=*/false));
}
else
{
pm.addPass(mlir::createLowerToLLVMPass(
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));
}
// Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm);
......
......@@ -43,6 +43,9 @@
#define PASS_NAME "convert-ngraph-to-affine"
#define DEBUG_TYPE PASS_NAME
// Enable the lowering of MemRefs to LLVM bare pointers.
extern llvm::cl::opt<bool> clEnableBarePtrMemRefLowering;
std::vector<ngraph::runtime::ngmlir::opAttrs> opAttrsVec;
// anonymous namespace
......@@ -344,8 +347,10 @@ namespace
// TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion.
// TODO: To be enabled in follow-up commit.
// insertNoAliasArgAttrs();
if (clEnableBarePtrMemRefLowering)
{
insertNoAliasArgAttrs();
}
}
opAttrsVec = m_attrsVec;
......@@ -520,22 +525,22 @@ namespace
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics.
// void DialectLoweringPass::insertNoAliasArgAttrs()
//{
// FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
// NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
// unsigned int argIdx = 0;
// for (auto arg : func.getArguments())
// {
// if (arg.getType().isa<MemRefType>())
// {
// func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
// }
// ++argIdx;
// }
//}
void DialectLoweringPass::insertNoAliasArgAttrs()
{
FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
unsigned int argIdx = 0;
for (auto arg : func.getArguments())
{
if (arg.getType().isa<MemRefType>())
{
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
}
++argIdx;
}
}
void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
{
......
......@@ -53,6 +53,13 @@ static llvm::cl::opt<std::string>
clObjectFilename("ngraph-mlir-object-filename",
llvm::cl::desc("Dump MLIR JITted-compiled object to file jitted_mlir.o"));
// The bare pointer calling convention lowers memref arguments to bare pointers to the memref
// element type.
llvm::cl::opt<bool> clEnableBarePtrMemRefLowering(
"ngraph-bare-ptr-memref-lowering",
llvm::cl::init(false),
llvm::cl::desc("Enable the lowering of MemRefs to LLVM bare pointers"));
void MLIRCPURuntime::run(const std::vector<MemRefArg>& args)
{
// run_internal(*reinterpret_cast<std::vector<void*>*>(args), shapeVec, stridesVec);
......@@ -108,6 +115,9 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
// Assign external tensor pointers to invocation arguments.
for (size_t i = 0, numArgs = m_invokeArgs.size(); i < numArgs; ++i)
{
if (!clEnableBarePtrMemRefLowering)
{
// Default memref lowering lowers memrefs to StaticMemRef descriptors.
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(m_invokeArgs[i]));
memRefArg->allocatedPtr = (*m_externalTensors)[i].m_tensor;
memRefArg->alignedPtr = (*m_externalTensors)[i].m_tensor;
......@@ -118,6 +128,13 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
memRefArg->shapeAndStrides[rank + j] = (*m_externalTensors)[i].m_strides[j];
}
}
else
{
// Custom memref lowering lowers memref arguments to bare pointers to tensors.
auto** memRefArg = reinterpret_cast<void**>(m_invokeArgs[i]);
*memRefArg = (*m_externalTensors)[i].m_tensor;
}
}
}
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
......@@ -143,33 +160,62 @@ void MLIRCPURuntime::cleanup()
// Free void double pointer arguments without freeing external tensor data.
for (auto* arg : m_invokeArgs)
{
if (!clEnableBarePtrMemRefLowering)
{
// Default memref lowering lowers memrefs to StaticMemRef descriptors.
auto* memRefArg = *(reinterpret_cast<StaticMemRef**>(arg));
free(memRefArg);
free(arg);
}
else
{
// Custom memref lowering lowers memref arguments to bare pointers to tensors.
auto** memRefArg = reinterpret_cast<void**>(arg);
free(memRefArg);
}
}
}
// The current call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// The default call ABI takes a single arg pointer (argPtr) pointing to a list of args.
// Each arg is a pointer to a StaticMemRef which contains a data pointer
//
// The args are laid out as follows
// argPtr-> arg[0]-> StaticMemRef -> <data>
// arg[1]-> StaticMemRef -> <data>
// ...
//
// The bare pointer ABI takes a single arg pointer pointing to data for that MemRef. Not extra
// information about the MemRef is passed at the moment. Example:
//
// Args are laid out as follows:
// arg0Ptr-> <data>
// arg1Ptr-> <data>
// ...
SmallVector<void*, 8> MLIRCPURuntime::allocateMemrefArgs()
{
SmallVector<void*, 8> args;
for (auto i = 0; i < m_externalTensors->size(); i++)
{
auto descriptor = allocateMemrefDescriptor(m_ranks[i]);
if (!clEnableBarePtrMemRefLowering)
{
// Default memref lowering lowers memrefs to StaticMemRef descriptors.
auto descriptor = allocateDefaultMemrefDescriptor(m_ranks[i]);
StaticMemRef** arg = reinterpret_cast<StaticMemRef**>(malloc(sizeof(StaticMemRef*)));
*arg = descriptor;
args.push_back(arg);
}
else
{
// Custom memref lowering lowers memref arguments to bare pointers to tensors.
auto** arg = reinterpret_cast<void**>(malloc(sizeof(void**)));
*arg = reinterpret_cast<void*>(malloc(sizeof(void*)));
args.push_back(arg);
}
}
return args;
}
StaticMemRef* MLIRCPURuntime::allocateMemrefDescriptor(size_t rank)
StaticMemRef* MLIRCPURuntime::allocateDefaultMemrefDescriptor(size_t rank)
{
// We only use StaticMemRef because that's what MLIR currently offers.
// We should expand this with different types and dynamic MemRefs
......
......@@ -69,8 +69,9 @@ namespace ngraph
/// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocateMemrefArgs();
/// Helper to allocate a mem ref object. Handles static shapes only for now.
StaticMemRef* allocateMemrefDescriptor(size_t);
/// Helper to allocate a default MemRef descriptor for LLVM. Handles static shapes
/// only for now.
StaticMemRef* allocateDefaultMemrefDescriptor(size_t);
private:
// Pointers to externally allocated memory for sub-graph's input and output tensors.
......
// RUN: ngraph-opt %s -convert-ngraph-to-affine -ngraph-bare-ptr-memref-lowering -split-input-file | FileCheck %s --check-prefix=BARE-PTR-CC
// RUN: ngraph-opt %s -convert-ngraph-to-affine -split-input-file | FileCheck %s --check-prefix=STD-CC
// Tests related to the bare pointer calling convention.
// Verify that the `noalias` attribute is generated when the bare pointer calling
// convention is used but not with the standard calling convention.
func @noalias_attribute(%arg0: !ng.tensor<16x!ng.i64>, %arg1: !ng.tensor<512x32xf32>){
"ng.return"() : () -> ()
}
// BARE-PTR-CC-LABEL: func @noalias_attribute
// BARE-PTR-CC-SAME: %{{.*}}: memref<16xi64> {llvm.noalias = true}
// BARE-PTR-CC-SAME: %{{.*}}: memref<512x32xf32> {llvm.noalias = true})
// STD-CC-LABEL: func @noalias_attribute
// STD-CC-NOT: llvm.noalias
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment