Commit 9b748d2c authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fix incorrect callback lookup (#3233)

* Fix incorrect callback lookup

* Use std.AllocOp(malloc) for static tensor allocations

* Clean up

* revert getCallDecl change

* style
parent 928dfde3
...@@ -106,7 +106,7 @@ namespace ...@@ -106,7 +106,7 @@ namespace
void runOnModule() override; void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter); SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter); Value* createTempTensor(Type type, PatternRewriter& rewriter);
mlir::Function* getCallDecl(StringRef name, mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
...@@ -235,28 +235,31 @@ namespace ...@@ -235,28 +235,31 @@ namespace
else else
{ {
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult->getType().cast<NGTensorType>();
auto newResult = createTempTensor( auto newResult =
m_typeConverter.convertType(tensorType), tensorType.getSizeInBytes(), rewriter); createTempTensor(m_typeConverter.convertType(tensorType), rewriter);
newResults.push_back(newResult); newResults.push_back(newResult);
} }
} }
return newResults; return newResults;
} }
Value* Value* DialectLoweringPass::createTempTensor(Type type, PatternRewriter& rewriter)
DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
{ {
auto callBackFunc = getCallDecl("__mlir_allocate", MemRefType memRefType = type.cast<MemRefType>();
{rewriter.getIndexType(), rewriter.getIndexType()},
{type}, NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");
rewriter);
SmallVector<mlir::Value*, 4> args = { Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
insertMemMgrDef(&rewriter), /* pointer to mem manager */
rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(), // TODO:
size)}; /* size to allocate */ // Enable dynamic memref allocation via call-back to nGraph allocator
auto newTemp = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args) // We should create a list of Values representing each dynamic dim
.getResult(0); // The values would be computed based on the shape of the input to the ng op we are lowering.
return newTemp; // E.g. If lowering concat, Value for dynamic concat axis will be the sum of input dims.
// The lowerer will generate code to compute the dims.
// This is better be done via std.AllocOp but we need to make it hookable to nGraph allocator call-back.
return alloc;
} }
void DialectLoweringPass::processFakeInstrs() void DialectLoweringPass::processFakeInstrs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment