Commit 1fb6fa96 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Remove Fake Instructions (#3568)

* WIP

* Fixes

* Increase bitwidth for arg idx attrib

* Minor fixes

* style-apply
parent 8757f8e3
......@@ -194,11 +194,6 @@ void MLIRCompiler::run(std::vector<void*>& external_tensors)
cleanup();
}
unsigned MLIRCompiler::get_mem_mgr_arg_id(mlir::FuncOp& func)
{
return func.getNumArguments() - 1;
}
// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module()
{
......@@ -708,16 +703,6 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
{
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i];
}
// Add pointer to memory manager
// malloc here since that's what allocateMemRefArguments use
// TODO (nmostafa): Better way of doing this ? Use builder allocator ?
MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
NGRAPH_CHECK(mem_mgr_arg != nullptr);
*mem_mgr_arg = &get_mem_mgr();
// inserting memory manager ptr in right location ?
NGRAPH_CHECK(m_invoke_args.size() == get_mem_mgr_arg_id(func));
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
}
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
......@@ -744,9 +729,6 @@ void MLIRCompiler::cleanup()
{
m_builder.reset(nullptr);
}
// Free allocated memory for JIT'ed code temps
m_mem_mgr.freeAll();
}
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args()
......
......@@ -79,11 +79,6 @@ namespace ngraph
/// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors);
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
/// Returns memory manager pointer argument ID in call interface.
unsigned get_mem_mgr_arg_id(mlir::FuncOp& func);
private:
struct TensorInfo
{
......@@ -172,9 +167,6 @@ namespace ngraph
TensorToInfoMap m_tensor_to_value_map;
static const MLIRCompOpMap op_dispatcher;
// Memory manager for temp allocations inside JIT'ed code
MLIRMemMgr m_mem_mgr;
// Optimization level used by MLIR and LLVM compilers.
static unsigned mlir_opt_level;
......
......@@ -67,10 +67,6 @@ class NG_Op<string mnemonic, list<OpTrait> traits = []> :
class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_TensorType:$res)> {}
// Base for fake instructions defining MemRef values
class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_MemRefType:$res)> {}
// Operations producing no results
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs)> {}
......@@ -309,6 +305,3 @@ def NGConvolutionOp :
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
// Fake ops
def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
......@@ -83,6 +83,54 @@ namespace
#include "op_lowerers.inc"
// FuncOp Conversion pattern
class FuncOpSignatureConversion : public ConversionPattern
{
public:
FuncOpSignatureConversion(MLIRContext* ctx, TypeConverter& converter)
: ConversionPattern(FuncOp::getOperationName(), 1, ctx)
, converter(converter)
{
}
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult matchAndRewrite(Operation* op,
ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const override
{
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
return matchFailure();
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
// Add result types as input args without mapping
result.addInputs(convertedResults);
// Create a new function with an updated signature.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end());
newFuncOp.setType(
FunctionType::get(result.getConvertedTypes(), {/*void*/}, funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
/// The type converter to use when rewriting the signature.
TypeConverter& converter;
};
// Helpers
template <typename RedOp>
void lowerIndexReduction(Operation* op,
......@@ -138,17 +186,12 @@ namespace
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
void findOutputValues();
void processFakeInstrs();
void insertNoAliasArgAttrs();
private:
NGraphTypeConverter typeConverter;
// Value holding mem manager passed pointer
SmallVector<Value*, 4> memMgrDefs;
// List of temporary memrefs to deallocate at end of function
SmallVector<Value*, 4> memRefsToDealloc;
// list of results values to add to func signature
SmallVector<Value*, 4> loweredOutputValues;
ngmlir::MLIRCompiler& compiler;
};
......@@ -157,17 +200,14 @@ namespace
// Create type converter and initialize conversion patterns.
NGraphTypeConverter converter;
OwningRewritePatternList patterns;
// Add default FuncOp type conversion. It replaces the incoming FuncOp with a *new* one
// with the converted types.
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), typeConverter);
populateNGraphToAffineConversionPatterns(patterns);
// Create target that defines legal ops for nGraph dialect to be lowered to.
ConversionTarget target(getContext());
// TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
// it as part of the lowering to affine/standard.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp, NGFakeInputOp>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types.
return typeConverter.isSignatureLegal(op.getType());
......@@ -184,7 +224,6 @@ namespace
signalPassFailure();
}
processFakeInstrs();
insertNoAliasArgAttrs();
}
......@@ -196,6 +235,9 @@ namespace
patterns.insert<
#include "op_lowerers.inc"
>(&getContext(), *this);
// FuncOp pattern
patterns.insert<FuncOpSignatureConversion>(&getContext(), typeConverter);
}
void DialectLoweringPass::findOutputValues()
......@@ -204,26 +246,23 @@ namespace
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
SmallVector<Value*, 4> outputList;
unsigned outputCount = 0;
unsigned inputCount = f.getType().getNumInputs();
// we find out output values by looking at returned values
// any return should return all outputs of the subgraph
f.walk([this, &outputCount](NGReturnOp ret) {
f.walk([this, &outputCount, inputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++)
{
// annotate instructions defining outputs with the arg idx of the output
auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp();
op->setAttr("graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i));
op->setAttr(
"graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(32, op->getContext()), i + inputCount));
}
NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
"Inconsistent returns in function");
outputCount = ret.getNumOperands();
});
// will be populated with lowered output values later
// TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code.
loweredOutputValues.resize(outputCount, nullptr);
}
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
......@@ -232,19 +271,13 @@ namespace
SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults())
{
// create output def if this operation produces any sub-graph outputs
// find output arg if this operation produces any sub-graph outputs
if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
{
unsigned argId = (int)attr.getInt();
auto fakeOp = rewriter.create<NGFakeInputOp>(
op->getLoc(),
typeConverter.convertType(origResult->getType()) /* convert to lowered type */
);
// Fake instrution is short-lived. Verify here.
fakeOp.verify();
auto newResult = fakeOp.getResult();
newResults.push_back(newResult);
loweredOutputValues[argId] = newResult;
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
mlir::Block* entryBlock = &*(f.begin());
unsigned argId = (unsigned)attr.getInt();
newResults.push_back(entryBlock->getArgument(argId));
}
else
{
......@@ -278,52 +311,6 @@ namespace
return alloc;
}
void DialectLoweringPass::processFakeInstrs()
{
auto context = getModule().getContext();
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
mlir::Block* entryBlock = &*(f.begin());
auto oldFuncType = f.getType();
ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
SmallVector<mlir::Type, 4> allArgs;
// Move all args as inputs in new type
for (auto type : ipArgs)
{
allArgs.push_back(type);
}
for (auto type : opArgs)
{
allArgs.push_back(type);
// add new value for result
entryBlock->addArgument(type);
}
// Mem Manager Ptr
auto indexType = mlir::IndexType::get(context);
allArgs.push_back(indexType);
entryBlock->addArgument(indexType);
// update type
auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
f.setType(newFuncType);
// RAUW fake outputs with result values
unsigned i = 0;
for (auto value : loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
}
for (auto v : memMgrDefs)
{
v->replaceAllUsesWith(entryBlock->getArgument(compiler.get_mem_mgr_arg_id(f)));
v->getDefiningOp()->erase();
}
}
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment