Commit e941412e authored by Nagy Mostafa, committed by nmostafa

[MLIR] Mem manager (#9)

* Implements a simple memory manager that just does malloc for now. Pointers are freed during cleanup.
* Enable JIT call-back to memory manager to allocate temps.
* Memory manager pointer is passed to the JIT'ed code upon invocation. That makes the code re-entrant from different threads in case the code is shared among identical sub-graphs that are executed in parallel.
parent ba735a80
......@@ -22,6 +22,7 @@ set(MLIR_SRC
mlir/dialect/ops.cpp
mlir/compiler.cpp
mlir/lowerer.cpp
mlir/memory_manager.cpp
)
set(SRC
cpu_backend.cpp
......@@ -222,10 +223,11 @@ if (NGRAPH_CPU_ENABLE)
add_definitions(${LLVM_DEFINITIONS})
target_include_directories(cpu_backend PRIVATE ${LLVM_INCLUDE_DIRS})
string(REPLACE ":" ";" MLIR_INCLUDE_PATH $ENV{MLIR_INCLUDE_PATH})
message(STATUS "MLIR Headers at : ${MLIR_INCLUDE_PATH}")
target_include_directories(cpu_backend PRIVATE ${MLIR_INCLUDE_PATH})
if(DEFINED ENV{MLIR_INCLUDE_PATH})
string(REPLACE ":" ";" MLIR_INCLUDE_PATH $ENV{MLIR_INCLUDE_PATH})
message(STATUS "MLIR Headers at : ${MLIR_INCLUDE_PATH}")
target_include_directories(cpu_backend PRIVATE ${MLIR_INCLUDE_PATH})
endif()
llvm_map_components_to_libnames(llvm_libs support core irreader)
......
......@@ -21,6 +21,7 @@
#include "ngraph/runtime/cpu/kernel/add.hpp"
#include "ngraph/runtime/cpu/mkldnn_invoke.hpp"
#include "ngraph/runtime/cpu/mkldnn_utils.hpp"
#include "ngraph/runtime/cpu/mlir/compiler.hpp"
using namespace std;
using namespace ngraph;
......
......@@ -59,12 +59,13 @@ namespace ngraph
// Register any LLVM command line options
llvm::cl::ParseEnvironmentOptions("ngraph", "MLIR_LLVM_OPTIONS", "");
}
/// Drives the whole flow for this sub-graph by calling each stage in sequence:
/// module construction, dialect lowering, optimization, argument binding,
/// execution and cleanup.
void MLIRCompiler::compile_and_run()
{
build_module(); // MLIR gen
lower_dialect();
optimize();
// NOTE(review): both bind_tensors_to_arguments() and bind_arguments() are called
// here; they look like the old and the new name of the same step — confirm that
// only one of them is meant to remain.
bind_tensors_to_arguments();
bind_arguments();
execute();
cleanup();
}
......@@ -114,7 +115,6 @@ namespace ngraph
}
}
/// Collects input and output tensors to this sub-graph
void MLIRCompiler::build_tensors_list()
{
for (const auto node : m_sub_graph)
......@@ -223,7 +223,7 @@ namespace ngraph
void MLIRCompiler::lower_dialect()
{
mlir::PassManager pm;
pm.addPass(createDialectLoweringPass());
pm.addPass(createDialectLoweringPass(this));
pm.addPass(mlir::createCanonicalizerPass());
pm.run(m_module.get());
......@@ -314,7 +314,7 @@ namespace ngraph
m_builder->create<NG_ReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
void MLIRCompiler::bind_tensors_to_arguments()
void MLIRCompiler::bind_arguments()
{
NGRAPH_ASSERT(m_module && "MLIR module is not ready.");
......@@ -326,11 +326,10 @@ namespace ngraph
// SmallVector<StaticFloatMemref*>. StaticFloatMemref is just a struct with the
// actual pointer to the data.
// TODO (dcab): Only f32 arguments are supported for now. We may want to implement
// this more generically by just allocating a void double pointer.
auto expected_arguments = allocateMemRefArguments(func);
NGRAPH_ASSERT(expected_arguments) << "Arguments can't be created";
m_invoke_args = std::move(*expected_arguments);
// create MemRef args
auto expected_arguments = allocate_memref_args(func);
NGRAPH_ASSERT(expected_arguments.size()) << "Arguments can't be created";
m_invoke_args = std::move(expected_arguments);
NGRAPH_ASSERT(m_invoke_args.size() == m_external_tensors.size())
<< "Number of external tensors doesn't match number of function arguments";
......@@ -340,6 +339,15 @@ namespace ngraph
{
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)m_external_tensors[i];
}
// Add pointer to memory manager
// malloc here since that's what allocateMemRefArguments use
// TODO (nmostafa): Better way of doing this ? Use builder allocator ?
MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
*mem_mgr_arg = &get_mem_mgr();
// inserting memory manager ptr in right location ?
NGRAPH_ASSERT(m_invoke_args.size() == get_mem_mgr_arg_id(func));
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}
void MLIRCompiler::execute()
......@@ -356,11 +364,6 @@ namespace ngraph
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
// Create an MLIR execution engine. Note that it takes a null pass manager
// to make sure it won't run "default" passes on the MLIR that would trigger
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
// the module.
// Create an MLIR execution engine. Note that it takes a null pass manager
// to make sure it won't run "default" passes on the MLIR that would trigger
// a second conversion to LLVM IR. The execution engine eagerly JIT-compiles
......@@ -389,5 +392,39 @@ namespace ngraph
// Free MLIR function builder.
if (m_builder)
m_builder.reset(nullptr);
// Free allocated memory for JIT'ed code temps
m_mem_mgr.freeAll();
}
/// Builds the list of MemRef descriptors for the arguments of `func`.
///
/// One descriptor is requested per function argument, in argument order.
/// Arguments for which allocate_memref_descriptor() returns null (i.e. whose
/// type is not a MemRef) contribute nothing to the returned list.
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args(mlir::Function* func)
{
    SmallVector<void*, 8> memref_args;
    memref_args.reserve(func->getNumArguments());

    for (const auto& func_arg : func->getArguments())
    {
        mlir::StaticFloatMemRef* memref = allocate_memref_descriptor(func_arg->getType());
        if (memref != nullptr)
        {
            memref_args.push_back(memref);
        }
    }

    return memref_args;
}
/// Allocates a MemRef descriptor for an argument of the given MLIR type.
///
/// Returns nullptr when `type` is not a MemRef type. Fails (NGRAPH_FAIL) when
/// the MemRef has any dynamic dimension, since only static shapes are handled.
/// The descriptor is obtained with malloc and its data pointer is initialised
/// to null; the caller fills it in and is responsible for the allocation.
mlir::StaticFloatMemRef* MLIRCompiler::allocate_memref_descriptor(mlir::Type type)
{
    auto memRefType = type.dyn_cast<mlir::MemRefType>();
    if (!memRefType)
        return nullptr;
    if (memRefType.getNumDynamicDims() != 0)
        NGRAPH_FAIL();
    // We only use StaticFloatMemRef because that's what MLIR currently offers.
    // We should expand this with different types and dynamic MemRefs
    auto* descriptor =
        static_cast<mlir::StaticFloatMemRef*>(malloc(sizeof(mlir::StaticFloatMemRef)));
    // malloc may fail; report it instead of writing through a null pointer below.
    NGRAPH_ASSERT(descriptor != nullptr) << "Failed to allocate MemRef descriptor";
    descriptor->data = nullptr;
    return descriptor;
}
}
......@@ -15,6 +15,8 @@
//*****************************************************************************
#pragma once
#include "lowerer.hpp"
#include "memory_manager.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/node.hpp"
......@@ -29,6 +31,11 @@
#include <mlir/IR/Types.h>
#include <mlir/StandardOps/Ops.h>
namespace mlir
{
struct StaticFloatMemRef;
}
namespace ngraph
{
namespace runtime
......@@ -37,6 +44,9 @@ namespace ngraph
{
class MLIRCompiler
{
public:
static void init_mlir();
public:
using TensorList = std::vector<descriptor::Tensor*>;
using TypeList = llvm::SmallVector<mlir::Type, 4>;
......@@ -48,10 +58,17 @@ namespace ngraph
{
}
static void init_mlir();
/// Compiles and runs a subgraph in MLIR.
/// Compiles and runs a subgraph in MLIR
void compile_and_run();
/// Accessor for the memory manager owned by this sub-graph compiler.
/// The returned reference refers to the compiler's own member.
MLIRMemMgr& get_mem_mgr()
{
    return m_mem_mgr;
}
/// Position of the memory manager pointer in the call interface of `func`:
/// the index of the function's last argument.
unsigned get_mem_mgr_arg_id(mlir::Function* func)
{
    const unsigned num_args = func->getNumArguments();
    return num_args - 1;
}
private:
struct TensorInfo
{
......@@ -63,15 +80,17 @@ namespace ngraph
void build_module();
void lower_dialect();
void optimize();
void bind_tensors_to_arguments();
void bind_arguments();
void execute();
void cleanup();
/// Collects input and output tensors to this sub-graph
void build_tensors_list();
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
void build_ng_dialect();
template <typename OP>
......@@ -85,6 +104,12 @@ namespace ngraph
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
void create_return();
/// Helper to create memref arguments for MLIR function signature
llvm::SmallVector<void*, 8> allocate_memref_args(mlir::Function* func);
/// Helper to allocate a mem ref object. Handles static shapes only for now.
mlir::StaticFloatMemRef* allocate_memref_descriptor(mlir::Type type);
private:
mlir::MLIRContext m_context;
std::unique_ptr<mlir::Module> m_module;
......@@ -107,6 +132,9 @@ namespace ngraph
// List of input and output tensors in the graph
TensorList m_ip_tensors, m_op_tensors;
static const MLIRCompOpMap op_dispatcher;
// Memory manager for temp allocations inside JIT'ed code
MLIRMemMgr m_mem_mgr;
};
}
}
......
......@@ -31,7 +31,7 @@ namespace ngraph
addOperations<NG_AddOp>();
addOperations<NG_MatmulBiasOp>();
addOperations<NG_ReturnOp>();
addOperations<NG_FakeOutput>();
addOperations<NG_FakeInput>();
}
void NGDialect::printType(mlir::Type type, raw_ostream& os) const
......
......@@ -69,14 +69,14 @@ namespace ngraph
}
}
void runtime::cpu::NG_FakeOutput::build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Type resultType)
void runtime::cpu::NG_FakeInput::build(mlir::Builder* builder,
mlir::OperationState* state,
mlir::Type resultType)
{
state->types.push_back(std::move(resultType));
}
mlir::LogicalResult runtime::cpu::NG_FakeOutput::verify()
mlir::LogicalResult runtime::cpu::NG_FakeInput::verify()
{
// TODO: Verify returned tensor types must match function return type.
return mlir::success();
......
......@@ -29,10 +29,15 @@ namespace ngraph
namespace cpu
{
// Fake instructions
class NG_FakeOutput : public mlir::Op<NG_FakeOutput,
mlir::OpTrait::NOperands<0>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect>
/// Fake Input
/// Used as fake definitions during dialect conversion.
/// Used when we cannot insert the real definition once during lowering.
/// They are cleaned up after dialect lowering and replaced with the real definition.
class NG_FakeInput : public mlir::Op<NG_FakeInput,
mlir::OpTrait::NOperands<0>::Impl,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect>
{
public:
/// Returns the operation name string for this op.
// NOTE(review): the name still reads "ng.fake.output" although the class is
// called NG_FakeInput — confirm whether the name should follow the rename.
static llvm::StringRef getOperationName() { return "ng.fake.output"; }
......
......@@ -94,6 +94,20 @@ namespace ngraph
EltType getElementType() const { return getImpl()->getElementType(); }
Shape getShape() const { return getImpl()->getShape(); }
int getRank() { return getShape().size(); }
/// Size of the tensor in bytes: the product of all dimensions times the
/// element size, where the element size is the element bit width rounded up
/// to whole bytes.
/// If any dimension is dynamic (-1) the function returns -1 converted to
/// size_t, which callers must treat as "size unknown".
size_t getSizeInBytes()
{
    auto dims = getShape();
    const int rank = getRank();

    size_t element_count = 1;
    for (int axis = 0; axis < rank; ++axis)
    {
        // A dynamic dimension makes the size unknown.
        if (dims[axis] == -1)
        {
            return -1;
        }
        element_count *= dims[axis];
    }

    // Bit width rounded up to a whole number of bytes.
    const auto bytes_per_element =
        llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
    return element_count * bytes_per_element;
}
/// convert to memref native MLIR type. Used for lowering.
mlir::MemRefType toMemref();
/// create a unique tensor type based on element type and shape.
......
......@@ -16,6 +16,7 @@
#include "lowerer.hpp"
#include <map>
#include "compiler.hpp"
#include "llvm/ADT/DenseSet.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/EDSC/Helpers.h"
......@@ -69,8 +70,9 @@ namespace
class DialectLoweringPass : public ModulePass<DialectLoweringPass>
{
public:
DialectLoweringPass()
DialectLoweringPass(MLIRCompiler& compiler)
: m_dialectLowerer(*this)
, m_compiler(compiler)
{
}
void runOnModule() override;
......@@ -78,26 +80,25 @@ namespace
SmallVector<Value*, 4> buildOutputDefs(Operation* op, FuncBuilder& rewriter);
private:
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
FuncBuilder& rewriter);
void findOutputValues();
void fixOutputs();
void processFakeInstrs();
Value* insertMemMgrDef(FuncBuilder* rewriter = nullptr);
private:
DialectLowerer m_dialectLowerer;
// Value holding mem manager passed pointer
SmallVector<Value*, 4> m_memMgrDefs;
// maps output ng dialect values to args pos
std::map<Value*, unsigned> m_outputValueMap;
// list of results values to add to func signature
SmallVector<Value*, 4> m_loweredOutputValues;
MLIRCompiler& m_compiler;
};
Type DialectLowerer::convertType(Type t)
{
if (auto tensor = t.cast<NGTensorType>())
{
return tensor.toMemref();
}
return t;
}
void DialectLoweringPass::runOnModule()
{
// capture output values by looking for the Return and grabbing the values
......@@ -115,7 +116,7 @@ namespace
{
getModule().dump();
}
fixOutputs();
processFakeInstrs();
if (std::getenv("NGRAPH_MLIR_DUMP_ALL") != nullptr)
{
getModule().dump();
......@@ -124,6 +125,7 @@ namespace
void DialectLoweringPass::findOutputValues()
{
// get original function
auto f = getModule().getNamedFunction("main");
SmallVector<Value*, 4> outputList;
unsigned outputCount = 0;
......@@ -143,7 +145,136 @@ namespace
m_loweredOutputValues.resize(outputCount, nullptr);
}
/// Inserts a fake definition (NG_FakeInput of index type) standing in for the
/// memory manager pointer, at the rewriter's current insertion point.
///
/// The created value is recorded in m_memMgrDefs so that processFakeInstrs()
/// can later replace it with the real function argument.
///
/// \param rewriter Builder used to create the op. Must not be null, even though
///                 the declaration gives it a null default.
/// \return The result value of the fake definition.
Value* DialectLoweringPass::insertMemMgrDef(FuncBuilder* rewriter)
{
    // The declaration defaults this parameter to nullptr, but every path below
    // dereferences it: fail loudly instead of crashing on a null builder.
    NGRAPH_ASSERT(rewriter != nullptr) << "insertMemMgrDef requires a valid rewriter";

    // it would be nice to insert one fake def at the start of the new func
    // however, due to how DialectConversion framework works, new func is only
    // materialized after conversion is done (rewriter->getFunction, or even rewriter->getInsertionBlock()->getFunction()
    // will give you the original func). This makes it very convoluted to insert instructions at entry block.
    auto op = rewriter->create<NG_FakeInput>(rewriter->getUnknownLoc(),
                                             IndexType::get(getModule().getContext()));
    // will be fixed later to read passed arg instead.
    m_memMgrDefs.push_back(op.getResult());
    return op.getResult();
}
/// Creates the lowered definitions for every result of `op`.
///
/// For a result that is a sub-graph output (present in the output value map) a
/// fake definition of the lowered type is created and remembered in
/// m_loweredOutputValues at the output's argument position. For any other result
/// a call to the "__mlir_allocate" call-back is emitted, passing the (fake)
/// memory manager pointer and the tensor size in bytes.
///
/// \param op       Operation whose results are being lowered.
/// \param rewriter Builder used to create the new operations.
/// \return One new value per result of `op`, in result order.
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
                                                            FuncBuilder& rewriter)
{
    auto& outputMap = getOutputValueMap();
    SmallVector<Value*, 4> newResults;
    for (auto origResult : op->getResults())
    {
        auto it = outputMap.find(origResult);
        // create output def if this operation produces any sub-graph outputs
        if (it != outputMap.end())
        {
            unsigned argId = (*it).second;
            auto newResult = rewriter
                                 .create<NG_FakeInput>(
                                     op->getLoc(),
                                     m_dialectLowerer.convertType(
                                         origResult->getType()) /* convert to lowered type */
                                     )
                                 .getResult();
            newResults.push_back(newResult);
            m_loweredOutputValues[argId] = newResult;
        }
        else
        {
            auto tensorType = origResult->getType().cast<NGTensorType>();
            auto callBackFunc = getCallDecl("__mlir_allocate",
                                            {rewriter.getIndexType(), rewriter.getIndexType()},
                                            {tensorType.toMemref()},
                                            rewriter);

            auto size = tensorType.getSizeInBytes();
            // getSizeInBytes() signals a dynamic dimension by returning -1 as size_t.
            // Emitting that value as the allocation size would request a bogus
            // amount of memory at run time, so reject it here.
            NGRAPH_ASSERT(size != static_cast<decltype(size)>(-1))
                << "Cannot allocate a temporary for a tensor with dynamic dimensions";

            SmallVector<mlir::Value*, 4> args = {
                insertMemMgrDef(&rewriter), /* pointer to mem manager */
                rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
                                                       size)}; /* size to allocate */
            auto newResult =
                rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
                    .getResult(0);
            newResults.push_back(newResult);
        }
    }
    return newResults;
}
/// Rewrites the signature of "main" and replaces all fake definitions with the
/// block arguments they stand for.
///
/// The new function type takes the original inputs, then the original results,
/// then one index-typed argument for the memory manager pointer, and returns
/// nothing. A matching entry-block argument is added for every result and for
/// the memory manager pointer. Fake output definitions and fake memory manager
/// definitions are then replaced by those arguments and erased.
void DialectLoweringPass::processFakeInstrs()
{
    auto context = getModule().getContext();
    auto f = getModule().getNamedFunction("main");
    mlir::Block* entryBlock = &*(f->begin());
    auto oldFuncType = f->getType();
    ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
    ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
    SmallVector<mlir::Type, 4> allArgs;

    // Move all args as inputs in new type
    for (auto type : ipArgs)
    {
        allArgs.push_back(type);
    }
    for (auto type : opArgs)
    {
        allArgs.push_back(type);
        // add new value for result
        entryBlock->addArgument(type);
    }

    // Mem Manager Ptr
    auto indexType = mlir::IndexType::get(context);
    allArgs.push_back(indexType);
    entryBlock->addArgument(indexType);

    // update type (must happen before get_mem_mgr_arg_id(f) is queried below,
    // since that index is derived from the function's argument count)
    auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
    f->setType(newFuncType);

    // RAUW fake outputs with result values
    unsigned i = 0;
    for (auto value : m_loweredOutputValues)
    {
        // Slots are pre-filled with nullptr and only set when an output is
        // lowered; check before dereferencing.
        NGRAPH_ASSERT(value != nullptr) << "sub-graph output was never lowered";
        auto op = value->getDefiningOp();
        // Check that a defining op exists before querying its kind.
        NGRAPH_ASSERT(op != nullptr && op->isa<NG_FakeInput>())
            << "output value not defined by fake output?";
        value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
        op->erase();
        i++;
    }

    // Replace every fake memory manager def with the real (last) argument.
    for (auto v : m_memMgrDefs)
    {
        v->replaceAllUsesWith(entryBlock->getArgument(m_compiler.get_mem_mgr_arg_id(f)));
        v->getDefiningOp()->erase();
    }
}
/// Looks up the function called `name` in the module and returns it; when no
/// such function exists yet, a new one with the given argument and result types
/// is created, appended to the module's function list and returned.
///
/// \param name     Name of the call-back function.
/// \param args     Argument types used when the function has to be created.
/// \param output   Result types used when the function has to be created.
/// \param rewriter Builder used to form the function type and location.
mlir::Function* DialectLoweringPass::getCallDecl(StringRef name,
                                                 ArrayRef<Type> args,
                                                 ArrayRef<Type> output,
                                                 FuncBuilder& rewriter)
{
    // Reuse an existing function of that name as is.
    if (auto* existingDecl = getModule().getNamedFunction(name))
    {
        return existingDecl;
    }

    auto declType = rewriter.getFunctionType(args, output);
    auto newDecl = llvm::make_unique<mlir::Function>(rewriter.getUnknownLoc(), name, declType);

    // The module's function list receives the raw pointer released here.
    mlir::Function* declPtr = newDecl.release();
    getModule().getFunctions().push_back(declPtr);
    return declPtr;
}
// NGDialect converters
/// Maps an NG tensor type to its MemRef equivalent; every other type is
/// returned unchanged.
Type DialectLowerer::convertType(Type t)
{
    auto tensorType = t.dyn_cast<NGTensorType>();
    if (!tensorType)
    {
        return t;
    }
    return tensorType.toMemref();
}
// ADD
SmallVector<Value*, 4> NG_AddOpConversion::rewrite(Operation* op,
ArrayRef<Value*> operands,
......@@ -256,69 +387,6 @@ namespace
rewriter.create<ReturnOp>(op->getLoc());
return {};
}
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
FuncBuilder& rewriter)
{
auto& outputMap = getOutputValueMap();
SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults())
{
auto it = outputMap.find(origResult);
// create output def if this operation produces any sub-graph outputs
if (it != outputMap.end())
{
unsigned argId = (*it).second;
auto newResult = rewriter
.create<NG_FakeOutput>(
op->getLoc(),
m_dialectLowerer.convertType(
origResult->getType()) /* convert to lowered type */
)
.getResult();
newResults.push_back(newResult);
m_loweredOutputValues[argId] = newResult;
}
}
return newResults;
}
void DialectLoweringPass::fixOutputs()
{
auto context = getModule().getContext();
auto f = getModule().getNamedFunction("main");
mlir::Block* entryBlock = &*(f->begin());
auto oldFuncType = f->getType();
ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
SmallVector<mlir::Type, 4> allArgs;
// Move all args as inputs in new type
for (auto type : ipArgs)
{
allArgs.push_back(type);
}
for (auto type : opArgs)
{
allArgs.push_back(type);
// add new value for result
entryBlock->addArgument(type);
}
// update type
auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
f->setType(newFuncType);
// RAUW fake outputs with result values
unsigned i = 0;
for (auto value : m_loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_ASSERT(op->isa<NG_FakeOutput>()) << "output value not defined by fake output?";
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
}
}
}
namespace ngraph
......@@ -327,7 +395,10 @@ namespace ngraph
{
namespace cpu
{
Pass* createDialectLoweringPass() { return new DialectLoweringPass(); }
/// Creates a new dialect lowering pass bound to the given compiler.
/// `compiler` is dereferenced, so it must point to a valid MLIRCompiler.
Pass* createDialectLoweringPass(MLIRCompiler* compiler)
{
    MLIRCompiler& owner = *compiler;
    return new DialectLoweringPass(owner);
}
}
}
}
......@@ -25,7 +25,9 @@ namespace ngraph
{
namespace cpu
{
mlir::Pass* createDialectLoweringPass();
class MLIRCompiler;
mlir::Pass* createDialectLoweringPass(MLIRCompiler* compiler);
}
}
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "memory_manager.hpp"
#include <llvm/ADT/STLExtras.h>
#include <memory>
#include "compiler.hpp"
#include "ngraph/runtime/cpu/cpu_backend_visibility.h"
using namespace ngraph::runtime::cpu;
/// Call-back used by JIT'ed code to obtain memory for a temporary.
/// Forwards the request to the memory manager passed in by the caller and
/// returns whatever pointer that manager hands back.
extern "C" CPU_BACKEND_API void* __mlir_allocate(MLIRMemMgr* mem_mgr, size_t size)
{
    void* const temp_buffer = mem_mgr->allocate(size);
    return temp_buffer;
}
/// Allocates `size` bytes with malloc, records the resulting pointer in the
/// manager's list and returns it. The pointer is recorded as is, including a
/// null result from malloc.
void* MLIRMemMgr::allocate(size_t size)
{
    ptrList.push_back(malloc(size));
    return ptrList.back();
}
/// Frees every pointer handed out by allocate() and forgets them, so that
/// calling freeAll() again is harmless.
void MLIRMemMgr::freeAll()
{
    for (void* p : ptrList)
    {
        free(p);
    }
    // Drop the now-dangling pointers. Without this, a second freeAll() (for
    // example from a second cleanup of the owning compiler) would free the same
    // pointers again.
    ptrList.clear();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <stdint.h>
#include <vector>
namespace ngraph
{
namespace runtime
{
namespace cpu
{
/// Memory manager for temporaries in MLIR compiled sub-graph
/// It handles call-backs from the code and returns pointer to allocated memory
/// Also, handles freeing up memory
///
/// The manager keeps the raw pointers it handed out, so it is not copyable:
/// two copies would hold the same pointers and each would free them.
class MLIRMemMgr
{
public:
    MLIRMemMgr() = default;

    // Copying would duplicate the list of owned pointers and lead to a double
    // free once freeAll() runs on both objects. Declaring these also suppresses
    // the implicit move operations.
    MLIRMemMgr(const MLIRMemMgr&) = delete;
    MLIRMemMgr& operator=(const MLIRMemMgr&) = delete;

    /// Allocates data for temporary tensor. Currently, it is called for each
    /// temp tensor definition. Keeps track of each pointer and free them during cleanup.
    // TODO: Use pre-allocation from framework memory manager
    void* allocate(size_t size);

    /// Frees all allocated pointers
    void freeAll();

private:
    // Every pointer returned by allocate(), in allocation order.
    std::vector<void*> ptrList;
};
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment