Commit 1fb6fa96 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Remove Fake Instructions (#3568)

* WIP

* Fixes

* Increase bitwidth for arg idx attrib

* Minor fixes

* style-apply
parent 8757f8e3
...@@ -194,11 +194,6 @@ void MLIRCompiler::run(std::vector<void*>& external_tensors) ...@@ -194,11 +194,6 @@ void MLIRCompiler::run(std::vector<void*>& external_tensors)
cleanup(); cleanup();
} }
/// Returns the index of the memory-manager pointer argument in the call
/// interface. It is always the last argument of the compiled function.
unsigned MLIRCompiler::get_mem_mgr_arg_id(mlir::FuncOp& func)
{
    // The mem-mgr pointer is appended after all tensor arguments, so its
    // index is simply one less than the total argument count.
    const unsigned lastArgIdx = func.getNumArguments() - 1;
    return lastArgIdx;
}
// Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel. // Creates an MLIR module and function with nGraph dialect ops from the input CompiledKernel.
void MLIRCompiler::build_ng_dialect_module() void MLIRCompiler::build_ng_dialect_module()
{ {
...@@ -708,16 +703,6 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors) ...@@ -708,16 +703,6 @@ void MLIRCompiler::bind_arguments(std::vector<void*>& external_tensors)
{ {
((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i]; ((mlir::StaticFloatMemRef*)m_invoke_args[i])->data = (float*)(*m_external_tensors)[i];
} }
// Add pointer to memory manager
// malloc here since that's what allocateMemRefArguments use
// TODO (nmostafa): Better way of doing this ? Use builder allocator ?
MLIRMemMgr** mem_mgr_arg = reinterpret_cast<MLIRMemMgr**>(malloc(sizeof(void*)));
NGRAPH_CHECK(mem_mgr_arg != nullptr);
*mem_mgr_arg = &get_mem_mgr();
// inserting memory manager ptr in right location ?
NGRAPH_CHECK(m_invoke_args.size() == get_mem_mgr_arg_id(func));
m_invoke_args.push_back(static_cast<void*>(mem_mgr_arg));
} }
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code. // Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
...@@ -744,9 +729,6 @@ void MLIRCompiler::cleanup() ...@@ -744,9 +729,6 @@ void MLIRCompiler::cleanup()
{ {
m_builder.reset(nullptr); m_builder.reset(nullptr);
} }
// Free allocated memory for JIT'ed code temps
m_mem_mgr.freeAll();
} }
SmallVector<void*, 8> MLIRCompiler::allocate_memref_args() SmallVector<void*, 8> MLIRCompiler::allocate_memref_args()
......
...@@ -79,11 +79,6 @@ namespace ngraph ...@@ -79,11 +79,6 @@ namespace ngraph
/// Executes a pre-compiled subgraph /// Executes a pre-compiled subgraph
void run(std::vector<void*>& external_tensors); void run(std::vector<void*>& external_tensors);
/// Returns the memory manager used by this sub-graph compiler.
MLIRMemMgr& get_mem_mgr() { return m_mem_mgr; }
/// Returns memory manager pointer argument ID in call interface.
unsigned get_mem_mgr_arg_id(mlir::FuncOp& func);
private: private:
struct TensorInfo struct TensorInfo
{ {
...@@ -172,9 +167,6 @@ namespace ngraph ...@@ -172,9 +167,6 @@ namespace ngraph
TensorToInfoMap m_tensor_to_value_map; TensorToInfoMap m_tensor_to_value_map;
static const MLIRCompOpMap op_dispatcher; static const MLIRCompOpMap op_dispatcher;
// Memory manager for temp allocations inside JIT'ed code
MLIRMemMgr m_mem_mgr;
// Optimization level used by MLIR and LLVM compilers. // Optimization level used by MLIR and LLVM compilers.
static unsigned mlir_opt_level; static unsigned mlir_opt_level;
......
...@@ -67,10 +67,6 @@ class NG_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -67,10 +67,6 @@ class NG_Op<string mnemonic, list<OpTrait> traits = []> :
class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []> : class NG_OneResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_TensorType:$res)> {} NG_Op<mnemonic, traits>, Results<(outs NG_TensorType:$res)> {}
// Base for fake instructions defining MemRef values.
// Unlike NG_OneResult_Op (which yields an NG_TensorType), ops derived from
// this class produce a single NG_MemRefType result; they act as short-lived
// placeholder definitions during lowering.
class NG_MemRefDef_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs NG_MemRefType:$res)> {}
// Operations producing no results // Operations producing no results
class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> : class NG_ZeroResult_Op<string mnemonic, list<OpTrait> traits = []> :
NG_Op<mnemonic, traits>, Results<(outs)> {} NG_Op<mnemonic, traits>, Results<(outs)> {}
...@@ -309,6 +305,3 @@ def NGConvolutionOp : ...@@ -309,6 +305,3 @@ def NGConvolutionOp :
// Terminator Ops // Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">; def NGReturnOp : NG_Terminator_Op<"return">;
// Fake ops
// Placeholder op (NoSideEffect, single MemRef result) standing in for values
// that are materialized later in lowering; removed by this commit.
def NGFakeInputOp : NG_MemRefDef_Op<"fake.input", [NoSideEffect]>;
...@@ -83,6 +83,54 @@ namespace ...@@ -83,6 +83,54 @@ namespace
#include "op_lowerers.inc" #include "op_lowerers.inc"
// FuncOp Conversion pattern
//
// Rewrites a FuncOp's signature during dialect lowering: each original input
// is run through the TypeConverter, every original result type is converted
// and then *appended to the input list* (outputs become destination-passing
// arguments), and the function's result list becomes empty (void). A new
// FuncOp with the converted signature replaces the original.
class FuncOpSignatureConversion : public ConversionPattern
{
public:
// NOTE(review): benefit 1 passed to ConversionPattern is the pattern's
// static match cost/priority; the converter reference must outlive this
// pattern object.
FuncOpSignatureConversion(MLIRContext* ctx, TypeConverter& converter)
: ConversionPattern(FuncOp::getOperationName(), 1, ctx)
, converter(converter)
{
}
/// Hook for derived classes to implement combined matching and rewriting.
PatternMatchResult matchAndRewrite(Operation* op,
ArrayRef<Value*> operands,
ConversionPatternRewriter& rewriter) const override
{
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
// Convert the original function arguments.
TypeConverter::SignatureConversion result(type.getNumInputs());
for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
if (failed(converter.convertSignatureArg(i, type.getInput(i), result)))
return matchFailure();
// Convert the original function results.
SmallVector<Type, 4> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
return matchFailure();
// Add result types as input args without mapping
result.addInputs(convertedResults);
// Create a new function with an updated signature.
// Clone the op shell only, then move the old body region into it so the
// original blocks/ops are preserved in the new function.
auto newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end());
// New signature: converted inputs + converted results as inputs, no
// return values ({/*void*/}).
newFuncOp.setType(
FunctionType::get(result.getConvertedTypes(), {/*void*/}, funcOp.getContext()));
// Tell the rewriter to convert the region signature.
// This remaps the entry block arguments to the converted types.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
// The original FuncOp produced no SSA results, so replace it with none.
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}
/// The type converter to use when rewriting the signature.
TypeConverter& converter;
};
// Helpers // Helpers
template <typename RedOp> template <typename RedOp>
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
...@@ -138,17 +186,12 @@ namespace ...@@ -138,17 +186,12 @@ namespace
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns); void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
void findOutputValues(); void findOutputValues();
void processFakeInstrs();
void insertNoAliasArgAttrs(); void insertNoAliasArgAttrs();
private: private:
NGraphTypeConverter typeConverter; NGraphTypeConverter typeConverter;
// Value holding mem manager passed pointer
SmallVector<Value*, 4> memMgrDefs;
// List of temporary memrefs to deallocate at end of function // List of temporary memrefs to deallocate at end of function
SmallVector<Value*, 4> memRefsToDealloc; SmallVector<Value*, 4> memRefsToDealloc;
// list of results values to add to func signature
SmallVector<Value*, 4> loweredOutputValues;
ngmlir::MLIRCompiler& compiler; ngmlir::MLIRCompiler& compiler;
}; };
...@@ -157,17 +200,14 @@ namespace ...@@ -157,17 +200,14 @@ namespace
// Create type converter and initialize conversion patterns. // Create type converter and initialize conversion patterns.
NGraphTypeConverter converter; NGraphTypeConverter converter;
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
// Add default FuncOp type conversion. It replaces the incoming FuncOp with a *new* one
// with the converted types.
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), typeConverter);
populateNGraphToAffineConversionPatterns(patterns); populateNGraphToAffineConversionPatterns(patterns);
// Create target that defines legal ops for nGraph dialect to be lowered to. // Create target that defines legal ops for nGraph dialect to be lowered to.
ConversionTarget target(getContext()); ConversionTarget target(getContext());
// TODO: Remove NGFakeInputOp. We need to set NGFakeInputOp as legal op because we generate
// it as part of the lowering to affine/standard.
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>(); target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp, NGFakeInputOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) { target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types. // FuncOp is legal only if types have been converted to Std types.
return typeConverter.isSignatureLegal(op.getType()); return typeConverter.isSignatureLegal(op.getType());
...@@ -184,7 +224,6 @@ namespace ...@@ -184,7 +224,6 @@ namespace
signalPassFailure(); signalPassFailure();
} }
processFakeInstrs();
insertNoAliasArgAttrs(); insertNoAliasArgAttrs();
} }
...@@ -196,6 +235,9 @@ namespace ...@@ -196,6 +235,9 @@ namespace
patterns.insert< patterns.insert<
#include "op_lowerers.inc" #include "op_lowerers.inc"
>(&getContext(), *this); >(&getContext(), *this);
// FuncOp pattern
patterns.insert<FuncOpSignatureConversion>(&getContext(), typeConverter);
} }
void DialectLoweringPass::findOutputValues() void DialectLoweringPass::findOutputValues()
...@@ -204,26 +246,23 @@ namespace ...@@ -204,26 +246,23 @@ namespace
auto f = getModule().lookupSymbol<mlir::FuncOp>("main"); auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
SmallVector<Value*, 4> outputList; SmallVector<Value*, 4> outputList;
unsigned outputCount = 0; unsigned outputCount = 0;
unsigned inputCount = f.getType().getNumInputs();
// we find out output values by looking at returned values // we find out output values by looking at returned values
// any return should return all outputs of the subgraph // any return should return all outputs of the subgraph
f.walk([this, &outputCount](NGReturnOp ret) { f.walk([this, &outputCount, inputCount](NGReturnOp ret) {
for (unsigned i = 0; i < ret.getNumOperands(); i++) for (unsigned i = 0; i < ret.getNumOperands(); i++)
{ {
// annotate instructions defining outputs with the arg idx of the output
auto outputValue = ret.getOperand(i); auto outputValue = ret.getOperand(i);
auto op = outputValue->getDefiningOp(); auto op = outputValue->getDefiningOp();
op->setAttr("graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(8, op->getContext()), i)); op->setAttr(
"graphOutputIdx",
mlir::IntegerAttr::get(IntegerType::get(32, op->getContext()), i + inputCount));
} }
NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(), NGRAPH_CHECK(outputCount == 0 || outputCount == ret.getNumOperands(),
"Inconsistent returns in function"); "Inconsistent returns in function");
outputCount = ret.getNumOperands();
}); });
// will be populated with lowered output values later
// TODO: This resize is making debugging obscure. When the container is not populated due
// to a bug, null pointers are used by the consumer leading to a crash more difficult to
// root-cause. We should try to change the current approach or introduce verification code.
loweredOutputValues.resize(outputCount, nullptr);
} }
SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op, SmallVector<Value*, 4> DialectLoweringPass::buildOutputDefs(Operation* op,
...@@ -232,19 +271,13 @@ namespace ...@@ -232,19 +271,13 @@ namespace
SmallVector<Value*, 4> newResults; SmallVector<Value*, 4> newResults;
for (auto origResult : op->getResults()) for (auto origResult : op->getResults())
{ {
// create output def if this operation produces any sub-graph outputs // find output arg if this operation produces any sub-graph outputs
if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx")) if (IntegerAttr attr = op->getAttrOfType<IntegerAttr>("graphOutputIdx"))
{ {
unsigned argId = (int)attr.getInt(); auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
auto fakeOp = rewriter.create<NGFakeInputOp>( mlir::Block* entryBlock = &*(f.begin());
op->getLoc(), unsigned argId = (unsigned)attr.getInt();
typeConverter.convertType(origResult->getType()) /* convert to lowered type */ newResults.push_back(entryBlock->getArgument(argId));
);
// Fake instrution is short-lived. Verify here.
fakeOp.verify();
auto newResult = fakeOp.getResult();
newResults.push_back(newResult);
loweredOutputValues[argId] = newResult;
} }
else else
{ {
...@@ -278,52 +311,6 @@ namespace ...@@ -278,52 +311,6 @@ namespace
return alloc; return alloc;
} }
/// Post-lowering fix-up pass over the "main" function (removed by this
/// commit in favor of FuncOpSignatureConversion):
///   1. Rebuilds the function type so all original results become trailing
///      input arguments, plus one extra index-typed arg for the memory
///      manager pointer, leaving the function with no results.
///   2. Replaces every NGFakeInputOp placeholder that stood in for an output
///      value with the corresponding new entry-block argument, then erases it.
///   3. Likewise replaces the fake memory-manager definitions with the final
///      (mem-mgr) argument.
/// Order matters: entry-block arguments must be added before the RAUW loops
/// that reference them by index.
void DialectLoweringPass::processFakeInstrs()
{
auto context = getModule().getContext();
auto f = getModule().lookupSymbol<mlir::FuncOp>("main");
mlir::Block* entryBlock = &*(f.begin());
auto oldFuncType = f.getType();
ArrayRef<mlir::Type> ipArgs = oldFuncType.getInputs();
ArrayRef<mlir::Type> opArgs = oldFuncType.getResults();
SmallVector<mlir::Type, 4> allArgs;
// Move all args as inputs in new type
for (auto type : ipArgs)
{
allArgs.push_back(type);
}
for (auto type : opArgs)
{
allArgs.push_back(type);
// add new value for result
entryBlock->addArgument(type);
}
// Mem Manager Ptr
// NOTE(review): modeled as an index type, presumably pointer-width on the
// target — confirm against the JIT calling convention.
auto indexType = mlir::IndexType::get(context);
allArgs.push_back(indexType);
entryBlock->addArgument(indexType);
// update type
// New signature takes everything as inputs and returns nothing.
auto newFuncType = mlir::FunctionType::get(allArgs, {}, context);
f.setType(newFuncType);
// RAUW fake outputs with result values
// loweredOutputValues[i] was defined by an NGFakeInputOp; output i now
// lives at argument index (numInputs + i) of the rebuilt signature.
unsigned i = 0;
for (auto value : loweredOutputValues)
{
auto op = value->getDefiningOp();
NGRAPH_CHECK(isa<NGFakeInputOp>(op), "output value not defined by fake output?");
value->replaceAllUsesWith(entryBlock->getArgument(oldFuncType.getNumInputs() + i));
op->erase();
i++;
}
// Redirect all uses of the fake mem-mgr definitions to the real trailing
// mem-mgr argument, then erase the placeholders.
for (auto v : memMgrDefs)
{
v->replaceAllUsesWith(entryBlock->getArgument(compiler.get_mem_mgr_arg_id(f)));
v->getDefiningOp()->erase();
}
}
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe /// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics. /// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs() void DialectLoweringPass::insertNoAliasArgAttrs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment