Commit 9b748d2c authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Fix incorrect callback lookup (#3233)

* Fix incorrect callback lookup

* Use std.AllocOp(malloc) for static tensor allocations

* Clean up

* revert getCallDecl change

* style
parent 928dfde3
......@@ -106,7 +106,7 @@ namespace
void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter);
Value* createTempTensor(Type type, PatternRewriter& rewriter);
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
......@@ -235,28 +235,31 @@ namespace
else
{
auto tensorType = origResult->getType().cast<NGTensorType>();
auto newResult = createTempTensor(
m_typeConverter.convertType(tensorType), tensorType.getSizeInBytes(), rewriter);
auto newResult =
createTempTensor(m_typeConverter.convertType(tensorType), rewriter);
newResults.push_back(newResult);
}
}
return newResults;
}
Value*
DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
Value* DialectLoweringPass::createTempTensor(Type type, PatternRewriter& rewriter)
{
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{type},
rewriter);
SmallVector<mlir::Value*, 4> args = {
insertMemMgrDef(&rewriter), /* pointer to mem manager */
rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
size)}; /* size to allocate */
auto newTemp = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
.getResult(0);
return newTemp;
MemRefType memRefType = type.cast<MemRefType>();
NGRAPH_CHECK(memRefType.hasStaticShape(), "Dynamic shapes are not supported");
Value* alloc = rewriter.create<mlir::AllocOp>(rewriter.getUnknownLoc(), memRefType);
// TODO:
// Enable dynamic memref allocation via call-back to nGraph allocator
// We should create a list of Values representing each dynamic dim
// The values would be computed based on the shape of the input to the ng op we are lowering.
// E.g. If lowering concat, Value for dynamic concat axis will be the sum of input dims.
// The lowerer will generate code to compute the dims.
// This is better be done via std.AllocOp but we need to make it hookable to nGraph allocator call-back.
return alloc;
}
void DialectLoweringPass::processFakeInstrs()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment