Commit 1d5d2024 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Deallocate all temp tensors before return (#3289)

* Deallocate all temp tensors before return

* style-apply
parent c0edf94d
...@@ -122,6 +122,9 @@ namespace ...@@ -122,6 +122,9 @@ namespace
ArrayRef<Type> output, ArrayRef<Type> output,
PatternRewriter& rewriter); PatternRewriter& rewriter);
/// Inserts dealloc Ops for each temporary allocated by AllocOp
void insertDeallocs(PatternRewriter& rewriter);
private: private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect. /// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns); void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
...@@ -134,8 +137,9 @@ namespace ...@@ -134,8 +137,9 @@ namespace
private: private:
NGraphTypeConverter typeConverter; NGraphTypeConverter typeConverter;
// Value holding mem manager passed pointer // Value holding mem manager passed pointer
SmallVector<Value*, 4> memMgrDefs; SmallVector<Value*, 4> m_memMgrDefs;
// List of temporary memrefs to deallocate at end of function
SmallVector<Value*, 4> m_memRefsToDealloc;
// list of results values to add to func signature // list of results values to add to func signature
SmallVector<Value*, 4> loweredOutputValues; SmallVector<Value*, 4> loweredOutputValues;
ngmlir::MLIRCompiler& compiler; ngmlir::MLIRCompiler& compiler;
...@@ -167,7 +171,6 @@ namespace ...@@ -167,7 +171,6 @@ namespace
} }
processFakeInstrs(); processFakeInstrs();
insertNoAliasArgAttrs(); insertNoAliasArgAttrs();
} }
...@@ -261,6 +264,7 @@ namespace ...@@ -261,6 +264,7 @@ namespace
NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported"); NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType); Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
m_memRefsToDealloc.push_back(alloc);
// TODO: // TODO:
// Enable dynamic memref allocation via call-back to nGraph allocator // Enable dynamic memref allocation via call-back to nGraph allocator
...@@ -336,6 +340,14 @@ namespace ...@@ -336,6 +340,14 @@ namespace
} }
} }
void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
{
for (auto value : m_memRefsToDealloc)
{
rewriter.create<DeallocOp>(rewriter.getUnknownLoc(), value);
}
}
mlir::Function* DialectLoweringPass::getCallDecl(StringRef name, mlir::Function* DialectLoweringPass::getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
ArrayRef<Type> output, ArrayRef<Type> output,
...@@ -765,6 +777,7 @@ namespace ...@@ -765,6 +777,7 @@ namespace
REWRITER(NGReturnOp) REWRITER(NGReturnOp)
{ {
m_pass.insertDeallocs(rewriter);
rewriter.replaceOpWithNewOp<ReturnOp>(op); rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess(); return matchSuccess();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment