Unverified commit a8a9bcb5, authored by Amy Zhuang and committed by GitHub

[MLIR] Use mkldnn callback for ConvBias. (#4205)

* [MLIR] Use mkldnn callback for ConvBias.

* Add try catch.

Fix opAttrsVec.

Add rank check for Gemm and MatMul.

* Fix merge error.

* Fix a bug.

* Fix lit test.

* Modify unit test.

* Fix merge error.

* Address PR feedback.

* Address PR feedback.

* Insert callback_init function to module.

* Fix lit tests.

* Fix a bug.

* Use a set of GlobalOps for attributes.

* Address PR feedback.

* Address PR feedback.

* Fix merge error.

* Fix style error.

* Fix style error.
Co-authored-by: Scott Cyphers <diyessi@users.noreply.github.com>
parent 3dce6fdb
......@@ -30,6 +30,7 @@
#include <llvm/Support/Debug.h>
#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/Dialect/AffineOps/EDSC/Intrinsics.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/StandardOps/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/Function.h>
......@@ -250,6 +251,63 @@ namespace
}
};
// Return llvm type for given attributes type
// Return the LLVM type that mirrors the C layout of the callback attribute
// record for the given attribute kind:
//   INT    : a single i64
//   CONVxD : {i8 withRelu, [x x i64] strides, dilation, padBelow, padAbove}
//   POOLxD : {i8 includePadding, [x x i64] shape, strides, padBelow, padAbove}
//   GEMM   : {i8 transA, i8 transB, i64 m/n/k/lda/ldb/ldc, f32 alpha, f32 beta,
//             i32 broadcastHint}
static LLVM::LLVMType getLLVMType(AttrsType attrsType, LLVM::LLVMDialect* llvmDialect)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray1DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 1);
    auto llvmArray2DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 2);
    auto llvmArray3DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 3);
    auto llvmF32Ty = LLVM::LLVMType::getFloatTy(llvmDialect);
    switch (attrsType)
    {
    case AttrsType::INT: return llvmI64Ty;
    case AttrsType::CONV1D:
        return LLVM::LLVMType::getStructTy(
            llvmDialect,
            {llvmI8Ty, llvmArray1DI64Ty, llvmArray1DI64Ty, llvmArray1DI64Ty, llvmArray1DI64Ty});
    case AttrsType::CONV2D:
        return LLVM::LLVMType::getStructTy(
            llvmDialect,
            {llvmI8Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty});
    case AttrsType::CONV3D:
        return LLVM::LLVMType::getStructTy(
            llvmDialect,
            {llvmI8Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty});
    case AttrsType::POOL2D:
        return LLVM::LLVMType::getStructTy(
            llvmDialect,
            {llvmI8Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty});
    case AttrsType::POOL3D:
        return LLVM::LLVMType::getStructTy(
            llvmDialect,
            {llvmI8Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty});
    case AttrsType::GEMM:
        return LLVM::LLVMType::getStructTy(llvmDialect,
                                           {llvmI8Ty,
                                            llvmI8Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmF32Ty,
                                            llvmF32Ty,
                                            llvmI32Ty});
    }
    // All enumerators are handled above; without this, control could fall off
    // the end of a non-void function (UB) if a new AttrsType is ever added.
    NGRAPH_UNREACHABLE("Unsupported attributes type");
}
// Materialize `valAttr` as an LLVM constant of type `llvmTy` and store it
// through the address computed by `gep`.
static void
    createStore(LLVM::LLVMType llvmTy, Attribute valAttr, LLVM::GEPOp gep, OpBuilder& builder)
{
    auto loc = builder.getUnknownLoc();
    auto constantOp = builder.create<LLVM::ConstantOp>(loc, llvmTy, valAttr);
    builder.create<LLVM::StoreOp>(loc, constantOp, gep);
}
/// Dialect Lowering Pass to affine ops
class DialectLoweringPass : public ModulePass<DialectLoweringPass>
{
......@@ -276,7 +334,20 @@ namespace
ArrayRef<Type> output,
PatternRewriter& rewriter);
inline size_t insertAttrs(opAttrs attrs);
// Return a GlobalOp with the given name
// If such GlobalOp does not exist, create one
mlir::LLVM::GlobalOp getGlobalOp(StringRef name,
LLVM::LLVMType globalType,
bool isConstant,
LLVM::Linkage linkageType,
Attribute initVal,
OpBuilder& rewriter);
/// Insert a function to the module which initializes the global variables
/// that hold the attributes information for callbacks.
void insertInitFunc();
inline int32_t insertAttrs(opAttrs attrs, AttrsType type);
MemoryAnalysis* getMemAnalysis() const { return m_memAnalysis; }
private:
......@@ -300,6 +371,7 @@ namespace
// Store the attributes needed by callback
std::vector<opAttrs> m_attrsVec;
std::vector<AttrsType> m_attrsTyVec;
};
void DialectLoweringPass::runOnModule()
......@@ -316,7 +388,7 @@ namespace
// Create target that defines legal ops for nGraph dialect to be lowered to.
ConversionTarget target(getContext());
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect, LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types.
......@@ -353,7 +425,7 @@ namespace
}
}
opAttrsVec = m_attrsVec;
insertInitFunc();
}
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
......@@ -570,9 +642,514 @@ namespace
return callBackFunc;
}
inline size_t DialectLoweringPass::insertAttrs(opAttrs attrs)
// Return the GlobalOp with the given name.
// If no such GlobalOp exists, create one (inserted at the start of the
// module) with the given type, constness, linkage, and initializer.
mlir::LLVM::GlobalOp DialectLoweringPass::getGlobalOp(StringRef name,
                                                      LLVM::LLVMType globalType,
                                                      bool isConstant,
                                                      LLVM::Linkage linkageType,
                                                      Attribute initVal,
                                                      OpBuilder& rewriter)
{
    auto module = getModule();
    if (auto globalVal = module.lookupSymbol<LLVM::GlobalOp>(name))
    {
        return globalVal;
    }
    // Create a global and insert it at the start of the module. The guard
    // restores the caller's insertion point on scope exit. Returning the
    // newly created op directly avoids a second symbol-table lookup.
    OpBuilder::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    return rewriter.create<LLVM::GlobalOp>(
        rewriter.getUnknownLoc(), globalType, isConstant, linkageType, name, initVal);
}
// Attribute is int64_t
// Store a single int64 attribute value into the global addressed by
// `globalPtr`. The global is declared with the largest attribute struct
// type, so it is first bitcast to an i64* before the store.
static void initINT(LLVM::LLVMDialect* llvmDialect,
                    int64_t intAttr,
                    LLVM::AddressOfOp globalPtr,
                    OpBuilder& builder)
{
    auto loc = builder.getUnknownLoc();
    auto i64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto i64PtrTy = getLLVMType(AttrsType::INT, llvmDialect).getPointerTo();
    auto typedPtr = builder.create<LLVM::BitcastOp>(loc, i64PtrTy, globalPtr);
    auto valueOp =
        builder.create<LLVM::ConstantOp>(loc, i64Ty, builder.getI64IntegerAttr(intAttr));
    builder.create<LLVM::StoreOp>(loc, valueOp, typedPtr);
}
/*
template <int N>
struct convAttrs
{
bool withRelu;
int64_t windowStrides[N];
int64_t windowDilation[N];
int64_t padBelow[N];
int64_t padAbove[N];
};
*/
// Populate a convAttrs<1> (1D convolution) attribute global through
// `globalPtr`. The global is declared with the largest attribute type, so it
// is first bitcast to a pointer to the CONV1D struct layout.
// `constants` holds pre-built i32 ConstantOps for values 0..10, used as GEP
// indices (LLVM requires constant i32 indices into struct types).
static void initCONV1D(LLVM::LLVMDialect* llvmDialect,
                       convAttrs<1>& convAttrs1d,
                       SmallVector<LLVM::ConstantOp, 12>& constants,
                       LLVM::AddressOfOp globalPtr,
                       OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI64PtrTy = llvmI64Ty.getPointerTo();
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray1DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 1);
    auto conv1dTy = getLLVMType(AttrsType::CONV1D, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), conv1dTy.getPointerTo(), globalPtr);
    // One GEP per struct field:
    // {withRelu, windowStrides, windowDilation, padBelow, padAbove}.
    SmallVector<LLVM::GEPOp, 6> geps;
    SmallVector<LLVM::LLVMType, 6> elemsTy{
        llvmI8Ty, llvmArray1DI64Ty, llvmArray1DI64Ty, llvmArray1DI64Ty, llvmArray1DI64Ty};
    for (auto j = 0; j < 5; j++)
    {
        auto gepConv1dOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepConv1dOp);
    }
    // Store attribute values.
    // withRelu (bool) is widened to an i8 for the stored struct.
    createStore(llvmI8Ty,
                builder.getI8IntegerAttr(static_cast<int8_t>(convAttrs1d.withRelu)),
                geps[0],
                builder);
    // Fields 1..4 are 1-element i64 arrays; GEP to element 0 of field k and
    // store the scalar value.
    int k = 1;
    for (auto& convAttr : {convAttrs1d.windowStrides[0],
                           convAttrs1d.windowDilation[0],
                           convAttrs1d.padBelow[0],
                           convAttrs1d.padAbove[0]})
    {
        auto gepStructOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        llvmI64PtrTy,
                                        geps[k],
                                        ArrayRef<Value>({constants[0], constants[0]}));
        createStore(llvmI64Ty, builder.getI64IntegerAttr(convAttr), gepStructOp, builder);
        k++;
    }
}
// Populate a convAttrs<2> (2D convolution) attribute global through
// `globalPtr`; see initCONV1D for the overall scheme. Here each array field
// holds 2 elements, so the flat value list below walks field-by-field with
// `k` selecting the struct field and `m` the element within it.
static void initCONV2D(LLVM::LLVMDialect* llvmDialect,
                       convAttrs<2>& convAttrs2d,
                       SmallVector<LLVM::ConstantOp, 12>& constants,
                       LLVM::AddressOfOp globalPtr,
                       OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI64PtrTy = llvmI64Ty.getPointerTo();
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray2DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 2);
    auto conv2dTy = getLLVMType(AttrsType::CONV2D, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), conv2dTy.getPointerTo(), globalPtr);
    // One GEP per struct field:
    // {withRelu, windowStrides, windowDilation, padBelow, padAbove}.
    SmallVector<LLVM::GEPOp, 6> geps;
    SmallVector<LLVM::LLVMType, 6> elemsTy{
        llvmI8Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty};
    for (auto j = 0; j < 5; j++)
    {
        auto gepConv2dOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepConv2dOp);
    }
    // Store attribute values.
    // withRelu (bool) is widened to an i8 for the stored struct.
    createStore(llvmI8Ty,
                builder.getI8IntegerAttr(static_cast<int8_t>(convAttrs2d.withRelu)),
                geps[0],
                builder);
    int k = 1, m = 0;
    for (auto& convAttr : {convAttrs2d.windowStrides[0],
                           convAttrs2d.windowStrides[1],
                           convAttrs2d.windowDilation[0],
                           convAttrs2d.windowDilation[1],
                           convAttrs2d.padBelow[0],
                           convAttrs2d.padBelow[1],
                           convAttrs2d.padAbove[0],
                           convAttrs2d.padAbove[1]})
    {
        auto gepStructOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        llvmI64PtrTy,
                                        geps[k],
                                        ArrayRef<Value>({constants[0], constants[m]}));
        createStore(llvmI64Ty, builder.getI64IntegerAttr(convAttr), gepStructOp, builder);
        // k increments after every 2 iterations
        if (m == 1)
        {
            k++;
        }
        // m be 0 or 1 alternatively
        m = (m + 1) % 2;
    }
}
// Populate a convAttrs<3> (3D convolution) attribute global through
// `globalPtr`; see initCONV1D for the overall scheme. Each array field holds
// 3 elements: `k` selects the struct field, `m` the element within it.
static void initCONV3D(LLVM::LLVMDialect* llvmDialect,
                       convAttrs<3>& convAttrs3d,
                       SmallVector<LLVM::ConstantOp, 12>& constants,
                       LLVM::AddressOfOp globalPtr,
                       OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI64PtrTy = llvmI64Ty.getPointerTo();
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray3DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 3);
    auto conv3dTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), conv3dTy.getPointerTo(), globalPtr);
    // One GEP per struct field:
    // {withRelu, windowStrides, windowDilation, padBelow, padAbove}.
    SmallVector<LLVM::GEPOp, 6> geps;
    SmallVector<LLVM::LLVMType, 6> elemsTy{
        llvmI8Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty};
    for (auto j = 0; j < 5; j++)
    {
        auto gepConv3dOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepConv3dOp);
    }
    // Store attribute values.
    // withRelu (bool) is widened to an i8 for the stored struct.
    createStore(llvmI8Ty,
                builder.getI8IntegerAttr(static_cast<int8_t>(convAttrs3d.withRelu)),
                geps[0],
                builder);
    int k = 1, m = 0;
    for (auto& convAttr : {convAttrs3d.windowStrides[0],
                           convAttrs3d.windowStrides[1],
                           convAttrs3d.windowStrides[2],
                           convAttrs3d.windowDilation[0],
                           convAttrs3d.windowDilation[1],
                           convAttrs3d.windowDilation[2],
                           convAttrs3d.padBelow[0],
                           convAttrs3d.padBelow[1],
                           convAttrs3d.padBelow[2],
                           convAttrs3d.padAbove[0],
                           convAttrs3d.padAbove[1],
                           convAttrs3d.padAbove[2]})
    {
        auto gepStructOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        llvmI64PtrTy,
                                        geps[k],
                                        ArrayRef<Value>({constants[0], constants[m]}));
        createStore(llvmI64Ty, builder.getI64IntegerAttr(convAttr), gepStructOp, builder);
        // k increments after every 3 iterations
        if (m == 2)
        {
            k++;
        }
        // m be 0, 1, or 2 repeatedly.
        m = (m + 1) % 3;
    }
}
/*
template <int N>
struct poolAttrs
{
bool includePaddingInAvgComputation;
int64_t windowShape[N];
int64_t windowStrides[N];
int64_t padBelow[N];
int64_t padAbove[N];
};
*/
// Populate a poolAttrs<2> (2D pooling) attribute global through `globalPtr`.
// Same scheme as initCONV2D, but the struct layout is
// {includePaddingInAvgComputation, windowShape, windowStrides, padBelow,
// padAbove} — see the poolAttrs struct comment above.
static void initPOOL2D(LLVM::LLVMDialect* llvmDialect,
                       poolAttrs<2>& poolAttrs2d,
                       SmallVector<LLVM::ConstantOp, 12>& constants,
                       LLVM::AddressOfOp globalPtr,
                       OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI64PtrTy = llvmI64Ty.getPointerTo();
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray2DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 2);
    auto pool2dTy = getLLVMType(AttrsType::POOL2D, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), pool2dTy.getPointerTo(), globalPtr);
    // One GEP per struct field.
    SmallVector<LLVM::GEPOp, 6> geps;
    SmallVector<LLVM::LLVMType, 6> elemsTy{
        llvmI8Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty, llvmArray2DI64Ty};
    for (auto j = 0; j < 5; j++)
    {
        auto gepPool2dOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepPool2dOp);
    }
    // Store attribute values.
    // includePaddingInAvgComputation (bool) is widened to an i8.
    createStore(llvmI8Ty,
                builder.getI8IntegerAttr(
                    static_cast<int8_t>(poolAttrs2d.includePaddingInAvgComputation)),
                geps[0],
                builder);
    // k selects the struct field, m the element within its 2-element array.
    int k = 1, m = 0;
    for (auto& poolAttr : {poolAttrs2d.windowShape[0],
                           poolAttrs2d.windowShape[1],
                           poolAttrs2d.windowStrides[0],
                           poolAttrs2d.windowStrides[1],
                           poolAttrs2d.padBelow[0],
                           poolAttrs2d.padBelow[1],
                           poolAttrs2d.padAbove[0],
                           poolAttrs2d.padAbove[1]})
    {
        auto gepStructOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        llvmI64PtrTy,
                                        geps[k],
                                        ArrayRef<Value>({constants[0], constants[m]}));
        createStore(llvmI64Ty, builder.getI64IntegerAttr(poolAttr), gepStructOp, builder);
        // k increments after every 2 iterations
        if (m == 1)
        {
            k++;
        }
        // m be 0 or 1 alternatively
        m = (m + 1) % 2;
    }
}
// Populate a poolAttrs<3> (3D pooling) attribute global through `globalPtr`.
// Same scheme as initPOOL2D but with 3-element array fields.
static void initPOOL3D(LLVM::LLVMDialect* llvmDialect,
                       poolAttrs<3>& poolAttrs3d,
                       SmallVector<LLVM::ConstantOp, 12>& constants,
                       LLVM::AddressOfOp globalPtr,
                       OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI64PtrTy = llvmI64Ty.getPointerTo();
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmArray3DI64Ty = LLVM::LLVMType::getArrayTy(llvmI64Ty, 3);
    auto pool3dTy = getLLVMType(AttrsType::POOL3D, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), pool3dTy.getPointerTo(), globalPtr);
    // One GEP per struct field.
    SmallVector<LLVM::GEPOp, 6> geps;
    SmallVector<LLVM::LLVMType, 6> elemsTy{
        llvmI8Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty, llvmArray3DI64Ty};
    for (auto j = 0; j < 5; j++)
    {
        auto gepPool3dOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepPool3dOp);
    }
    // Store attribute values.
    // includePaddingInAvgComputation (bool) is widened to an i8.
    createStore(llvmI8Ty,
                builder.getI8IntegerAttr(
                    static_cast<int8_t>(poolAttrs3d.includePaddingInAvgComputation)),
                geps[0],
                builder);
    // k selects the struct field, m the element within its 3-element array.
    int k = 1, m = 0;
    for (auto& poolAttr : {poolAttrs3d.windowShape[0],
                           poolAttrs3d.windowShape[1],
                           poolAttrs3d.windowShape[2],
                           poolAttrs3d.windowStrides[0],
                           poolAttrs3d.windowStrides[1],
                           poolAttrs3d.windowStrides[2],
                           poolAttrs3d.padBelow[0],
                           poolAttrs3d.padBelow[1],
                           poolAttrs3d.padBelow[2],
                           poolAttrs3d.padAbove[0],
                           poolAttrs3d.padAbove[1],
                           poolAttrs3d.padAbove[2]})
    {
        auto gepStructOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        llvmI64PtrTy,
                                        geps[k],
                                        ArrayRef<Value>({constants[0], constants[m]}));
        createStore(llvmI64Ty, builder.getI64IntegerAttr(poolAttr), gepStructOp, builder);
        // k increments after every 3 iterations
        if (m == 2)
        {
            k++;
        }
        // m be 0, 1, or 2 repeatedly
        m = (m + 1) % 3;
    }
}
/*
struct gemmAttrs
{
bool transposeA;
bool transposeB;
int64_t m;
int64_t n;
int64_t k;
int64_t lda;
int64_t ldb;
int64_t ldc;
float alpha;
float beta;
BroadcastType broadcastHint;
};
*/
// Populate a gemmAttrs attribute global through `globalPtr` (layout in the
// gemmAttrs struct comment above). Unlike the conv/pool structs, every field
// here is a scalar, so one GEP per field suffices.
// NOTE(review): gemmAttrs2d is passed by value (a copy); const& would avoid
// it — confirm no caller relies on the copy.
static void initGEMM(LLVM::LLVMDialect* llvmDialect,
                     gemmAttrs gemmAttrs2d,
                     SmallVector<LLVM::ConstantOp, 12>& constants,
                     LLVM::AddressOfOp globalPtr,
                     OpBuilder& builder)
{
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
    auto llvmI8Ty = LLVM::LLVMType::getInt8Ty(llvmDialect);
    auto llvmF32Ty = LLVM::LLVMType::getFloatTy(llvmDialect);
    auto gemmTy = getLLVMType(AttrsType::GEMM, llvmDialect);
    auto castOp = builder.create<LLVM::BitcastOp>(
        builder.getUnknownLoc(), gemmTy.getPointerTo(), globalPtr);
    // One GEP per struct field, in declaration order.
    SmallVector<LLVM::GEPOp, 12> geps;
    SmallVector<LLVM::LLVMType, 12> elemsTy{llvmI8Ty,
                                            llvmI8Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmI64Ty,
                                            llvmF32Ty,
                                            llvmF32Ty,
                                            llvmI32Ty};
    for (auto j = 0; j < 11; j++)
    {
        auto gepGemmOp =
            builder.create<LLVM::GEPOp>(builder.getUnknownLoc(),
                                        elemsTy[j].getPointerTo(),
                                        castOp,
                                        ArrayRef<Value>({constants[0], constants[j]}));
        geps.push_back(gepGemmOp);
    }
    // Store attribute values
    // Fields 0-1: transpose flags (bool widened to i8).
    int k = 0;
    for (auto& gemmAttr : {gemmAttrs2d.transposeA, gemmAttrs2d.transposeB})
    {
        createStore(llvmI8Ty,
                    builder.getI8IntegerAttr(static_cast<int8_t>(gemmAttr)),
                    geps[k],
                    builder);
        k++;
    }
    // Fields 2-7: m, n, k, lda, ldb, ldc as i64.
    for (auto& gemmAttr : {gemmAttrs2d.m,
                           gemmAttrs2d.n,
                           gemmAttrs2d.k,
                           gemmAttrs2d.lda,
                           gemmAttrs2d.ldb,
                           gemmAttrs2d.ldc})
    {
        createStore(llvmI64Ty, builder.getI64IntegerAttr(gemmAttr), geps[k], builder);
        k++;
    }
    // Fields 8-9: alpha and beta as f32.
    for (auto& gemmAttr : {gemmAttrs2d.alpha, gemmAttrs2d.beta})
    {
        createStore(llvmF32Ty, builder.getF32FloatAttr(gemmAttr), geps[k], builder);
        k++;
    }
    // Field 10: broadcast hint enum stored as i32.
    createStore(llvmI32Ty,
                builder.getI32IntegerAttr(static_cast<int32_t>(gemmAttrs2d.broadcastHint)),
                geps[10],
                builder);
}
void DialectLoweringPass::insertInitFunc()
{
auto module = getModule();
OpBuilder builder(module.getContext());
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
builder.setInsertionPointToStart(module.getBody());
// Insert init function
auto funcTy = builder.getFunctionType({}, {});
SmallVector<NamedAttribute, 4> attributes;
auto funcOp = builder.create<mlir::FuncOp>(
builder.getUnknownLoc(), "callback_init", funcTy, attributes);
// Insert entry block
auto entry = funcOp.addEntryBlock();
builder.setInsertionPointToStart(entry);
if (m_attrsVec.size() == 0)
{
// No callbacks, just return
builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
return;
}
// Insert operations into entry block
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
// constants needed by gep
// LLVM requires that structure indexes be (vectors of) 32-bit integer constants.
SmallVector<LLVM::ConstantOp, 12> constants;
auto maxNumElem = 11;
for (auto i = 0; i < maxNumElem; i++)
{
auto constant = builder.create<LLVM::ConstantOp>(
builder.getUnknownLoc(), llvmI32Ty, builder.getI32IntegerAttr(i));
constants.push_back(constant);
}
auto globalType = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto gepTy = getLLVMType(AttrsType::CONV3D, llvmDialect).getPointerTo();
int32_t i = 0;
for (auto attrs : m_attrsVec)
{
StringRef name = "globalAttrs" + std::to_string(i);
LLVM::GlobalOp globalVal = getGlobalOp(name,
globalType,
false,
LLVM::Linkage::Internal,
builder.getZeroAttr(globalType),
builder);
auto globalPtr = builder.create<LLVM::AddressOfOp>(builder.getUnknownLoc(), globalVal);
switch (m_attrsTyVec[i])
{
case AttrsType::INT: initINT(llvmDialect, attrs.intAttr, globalPtr, builder); break;
case AttrsType::CONV1D:
initCONV1D(llvmDialect, attrs.convAttrs1d, constants, globalPtr, builder);
break;
case AttrsType::CONV2D:
initCONV2D(llvmDialect, attrs.convAttrs2d, constants, globalPtr, builder);
break;
case AttrsType::CONV3D:
initCONV3D(llvmDialect, attrs.convAttrs3d, constants, globalPtr, builder);
break;
case AttrsType::POOL2D:
initPOOL2D(llvmDialect, attrs.poolAttrs2d, constants, globalPtr, builder);
break;
case AttrsType::POOL3D:
initPOOL3D(llvmDialect, attrs.poolAttrs3d, constants, globalPtr, builder);
break;
case AttrsType::GEMM:
initGEMM(llvmDialect, attrs.gemmAttrs2d, constants, globalPtr, builder);
break;
default: break;
}
i++;
}
builder.create<mlir::ReturnOp>(builder.getUnknownLoc());
}
// Record one callback-attribute entry together with its kind (the two
// vectors are kept in lockstep) and return its index, which later names the
// corresponding "globalAttrs<index>" global.
inline int32_t DialectLoweringPass::insertAttrs(opAttrs attrs, AttrsType type)
{
    m_attrsTyVec.push_back(type);
    m_attrsVec.push_back(attrs);
    return static_cast<int32_t>(m_attrsVec.size()) - 1;
}
......@@ -1080,7 +1657,7 @@ namespace
return matchSuccess();
}
// Use callback: Pooling, MatMul, Gemm, Softmax
// Use callback: Pooling, MatMul, Gemm, Softmax, ConvBias
static void castMemRef(SmallVector<mlir::Value, 4>& inputs,
SmallVector<mlir::Value, 4>& outputs,
PatternRewriter& rewriter,
......@@ -1093,6 +1670,23 @@ namespace
}
}
/// Return the address of the attribute global "globalAttrs<index>", creating
/// the global (zero-initialized, internal linkage, typed with the largest
/// attribute struct) if it does not exist yet.
static LLVM::AddressOfOp
    getGlobalAddr(int32_t index, PatternRewriter& rewriter, DialectLoweringPass& pass)
{
    auto module = pass.getModule();
    auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
    auto globalTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
    // Keep the name in a std::string: binding a StringRef to the temporary
    // returned by operator+ would leave it dangling before use.
    std::string name = "globalAttrs" + std::to_string(index);
    LLVM::GlobalOp globalVal = pass.getGlobalOp(name,
                                                globalTy,
                                                false,
                                                LLVM::Linkage::Internal,
                                                rewriter.getZeroAttr(globalTy),
                                                rewriter);
    auto globalPtr = rewriter.create<LLVM::AddressOfOp>(rewriter.getUnknownLoc(), globalVal);
    return globalPtr;
}
REWRITER(NGAvgPoolOp)
{
lowerPooling<mlir::NGAvgPoolOp>(op, operands, rewriter, pass);
......@@ -1145,19 +1739,8 @@ namespace
(srcShape.size() == 5 && resultShape.size() == 5),
"MKLDNN pooling operation is only supported for 3D and 5D tensors");
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value, 4> inputs = {src, delta, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl(
"callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
opAttrs attrs;
int32_t index = 0;
if (srcShape.size() == 4)
{
attrs.poolAttrs2d.includePaddingInAvgComputation = false;
......@@ -1168,6 +1751,7 @@ namespace
attrs.poolAttrs2d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs2d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::POOL2D);
}
else if (srcShape.size() == 5)
{
......@@ -1179,15 +1763,28 @@ namespace
attrs.poolAttrs3d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs3d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::POOL3D);
}
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
FuncOp callBackFunc = pass.getCallDecl(
"callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, unionTy.getPointerTo(), int64Ty},
{},
rewriter);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MAXPOOLBACKPROP), 64);
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 4> inputs = {src, delta, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 6> args = {
outputs[0], outputs[1], outputs[2], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
return matchSuccess();
......@@ -1241,26 +1838,30 @@ namespace
attrs.gemmAttrs2d.n = rhsShape[0];
}
attrs.gemmAttrs2d.ldc = attrs.gemmAttrs2d.n;
attrs.gemmAttrs2d.alpha = 1.0;
attrs.gemmAttrs2d.beta = 0.0;
BroadcastType broadcastHint = BroadcastType::NONE;
auto index = pass.insertAttrs(attrs, AttrsType::GEMM);
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl(
"callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, unionTy.getPointerTo(), int64Ty},
{},
rewriter);
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MATMUL), 64);
SmallVector<mlir::Value, 4> inputs = {lhs, rhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 6> args = {
outputs[0], outputs[1], outputs[2], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
......@@ -1358,7 +1959,11 @@ namespace
}
NGRAPH_CHECK(broadcastHint != BroadcastType::ERROR, "Unhandled broadcast");
attrs.gemmAttrs2d.broadcastHint = broadcastHint;
auto index = pass.insertAttrs(attrs, AttrsType::GEMM);
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl("callback_3_inputs",
......@@ -1366,22 +1971,19 @@ namespace
unrankedMemrefTy,
unrankedMemrefTy,
unrankedMemrefTy,
int64Ty,
unionTy.getPointerTo(),
int64Ty},
{},
rewriter);
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::GEMM), 64);
SmallVector<mlir::Value, 4> inputs = {lhs, rhs, bias, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 4> args = {
outputs[0], outputs[1], outputs[2], outputs[3], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 6> args = {
outputs[0], outputs[1], outputs[2], outputs[3], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
......@@ -1413,28 +2015,126 @@ namespace
(lhsShape.size() == 4 && resultShape.size() == 4),
"MKLDNN Softmax operation is only supported for 2D and 4D tensors");
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto axes = softmax.axes().getValue();
opAttrs attrs;
attrs.intAttr = axes[0].cast<IntegerAttr>().getInt();
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64);
auto index = pass.insertAttrs(attrs, AttrsType::INT);
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
FuncOp callBackFunc =
pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{unrankedMemrefTy, unrankedMemrefTy, unionTy.getPointerTo(), int64Ty},
{},
rewriter);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64);
SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
return matchSuccess();
}
REWRITER(NGConvBiasOp)
{
auto convBias = cast<NGConvBiasOp>(op);
auto loc = convBias.getLoc();
ScopedContext scope(rewriter, loc);
// Get operands
Value result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in ConvBias Op");
Value images = operands[0];
Value filters = operands[1];
Value bias = operands[2];
auto strides = convBias.strides().getValue();
auto dilation = convBias.dilation().getValue();
auto padBelow = convBias.padBelow().getValue();
auto padAbove = convBias.padBelow().getValue();
auto resultTy = result.getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto imagesTy = images.getType().dyn_cast<MemRefType>();
auto imagesShape = imagesTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(imagesTy, "Unexpected non-memref LHS type");
Type elemTy = resultTy.getElementType();
NGRAPH_CHECK(elemTy == imagesTy.getElementType(), "Types mismatch in ConvBias");
NGRAPH_CHECK((imagesShape.size() == 3 && resultShape.size() == 3) ||
(imagesShape.size() == 4 && resultShape.size() == 4) ||
(imagesShape.size() == 5 && resultShape.size() == 5),
"MKLDNN conv operation is only supported for 3D, 4D, and 5D tensors");
opAttrs attrs;
size_t index = 0;
if (imagesShape.size() == 3)
{
attrs.convAttrs1d.withRelu = convBias.withRelu();
attrs.convAttrs1d.windowStrides[0] = strides[0].cast<IntegerAttr>().getInt();
attrs.convAttrs1d.windowDilation[0] = dilation[0].cast<IntegerAttr>().getInt();
attrs.convAttrs1d.padBelow[0] = padBelow[0].cast<IntegerAttr>().getInt();
attrs.convAttrs1d.padAbove[0] = padAbove[0].cast<IntegerAttr>().getInt();
index = pass.insertAttrs(attrs, AttrsType::CONV1D);
}
else if (imagesShape.size() == 4)
{
attrs.convAttrs2d.withRelu = convBias.withRelu();
for (auto i = 0; i < 2; i++)
{
attrs.convAttrs2d.windowStrides[i] = strides[i].cast<IntegerAttr>().getInt();
attrs.convAttrs2d.windowDilation[i] = dilation[i].cast<IntegerAttr>().getInt();
attrs.convAttrs2d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.convAttrs2d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::CONV2D);
}
else if (imagesShape.size() == 5)
{
attrs.convAttrs3d.withRelu = convBias.withRelu();
for (auto i = 0; i < 3; i++)
{
attrs.convAttrs3d.windowStrides[i] = strides[i].cast<IntegerAttr>().getInt();
attrs.convAttrs3d.windowDilation[i] = dilation[i].cast<IntegerAttr>().getInt();
attrs.convAttrs3d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.convAttrs3d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::CONV3D);
}
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
FuncOp callBackFunc = pass.getCallDecl("callback_3_inputs",
{unrankedMemrefTy,
unrankedMemrefTy,
unrankedMemrefTy,
unrankedMemrefTy,
unionTy.getPointerTo(),
int64Ty},
{},
rewriter);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::CONVOLUTIONBIAS), 64);
SmallVector<mlir::Value, 4> inputs = {images, filters, bias, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 6> args = {
outputs[0], outputs[1], outputs[2], outputs[3], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
......@@ -2114,18 +2814,8 @@ namespace
NGRAPH_UNREACHABLE("Unsupported pooling op");
}
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc =
pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
opAttrs attrs;
size_t index = 0;
if (lhsShape.size() == 4)
{
attrs.poolAttrs2d.includePaddingInAvgComputation = includePadding;
......@@ -2136,6 +2826,7 @@ namespace
attrs.poolAttrs2d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs2d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::POOL2D);
}
else if (lhsShape.size() == 5)
{
......@@ -2147,14 +2838,26 @@ namespace
attrs.poolAttrs3d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs3d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
index = pass.insertAttrs(attrs, AttrsType::POOL3D);
}
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
// Get callback func
auto module = pass.getModule();
auto* llvmDialect = module.getContext()->getRegisteredDialect<mlir::LLVM::LLVMDialect>();
auto unionTy = getLLVMType(AttrsType::CONV3D, llvmDialect);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
FuncOp callBackFunc =
pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, unionTy.getPointerTo(), int64Ty},
{},
rewriter);
// Insert call
auto globalPtr = getGlobalAddr(index, rewriter, pass);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(ty), 64);
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], attrsIndexArg, opTypeArg};
SmallVector<mlir::Value, 4> inputs = {lhs, result};
SmallVector<mlir::Value, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
SmallVector<mlir::Value, 4> args = {outputs[0], outputs[1], globalPtr, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
}
......
......@@ -31,6 +31,7 @@ MLIR_OP(NGAvgPoolOp , false )
MLIR_OP(NGAvgPoolBackpropOp , false )
MLIR_OP(NGConcatOp , true )
MLIR_OP(NGConvolutionOp , false )
MLIR_OP(NGConvBiasOp , false )
MLIR_OP(NGDivOp , true )
MLIR_OP(NGDotOp , false )
MLIR_OP(NGGatherOp , false )
......
......@@ -942,8 +942,8 @@ def NGDepthToSpaceOp :
def NGConvBiasOp :
NG_OneResult_Op<"convBias", [NoSideEffect, DeclareOpInterfaceMethods<FusedOp>]>,
Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters, NG_TensorType:$bias,
I64ArrayAttr:$strides, I64ArrayAttr:$padBelow, I64ArrayAttr:$padAbove,
DefaultValuedAttr<BoolAttr, "false">:$withRelu)>
I64ArrayAttr:$strides, I64ArrayAttr:$dilation, I64ArrayAttr:$padBelow,
I64ArrayAttr:$padAbove, DefaultValuedAttr<BoolAttr, "false">:$withRelu)>
{
let summary = "Convolution Bias Op";
let description = "Convolution + bias forward prop for batched convolution operation.";
......@@ -967,9 +967,10 @@ def NGConvBiasOp :
let extraClassDeclaration = [{
void setStrides(const ArrayAttr& attr) { this->setAttr("strides", attr); }
void setDilation(const ArrayAttr& attr) { this->setAttr("dilation", attr); }
void setPadAbove(const ArrayAttr& attr) { this->setAttr("padAbove", attr); }
void setPadBelow(const ArrayAttr& attr) { this->setAttr("padBelow", attr); }
void setWithRelu(const Attribute& attr) {this->setAttr("withRelu", attr); }
void setWithRelu(const Attribute& attr) { this->setAttr("withRelu", attr); }
}];
}
......
......@@ -12,6 +12,7 @@ MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Concat)
MLIR_OP(Convolution)
MLIR_OP(ConvolutionBias)
MLIR_OP(Gather)
MLIR_OP(Gemm)
MLIR_OP(Greater)
......
......@@ -420,6 +420,47 @@ void MLIRSubgraphExtractionPass::sanity_check(std::shared_ptr<Function> func, No
}
}
// Check if convolution related nodes such as Convolution, ConvolutionBias,
// ConvolutionRelu, ... can use callback.
//
// Returns true only when the node's configuration is representable by the
// mkldnn callback: unit data dilation, non-negative padding, 1D/2D/3D
// spatial rank, and f32 element types throughout.
template <typename T>
static bool can_use_mkldnn_conv_callback(ngraph::Node* node)
{
    auto convolution = static_cast<const T*>(node);
    auto arg0_rank = node->get_input_shape(0).size();

    // MKLDNN has no notion of data dilation; only unit strides are supported.
    auto dilation = convolution->get_data_dilation_strides();
    if (std::any_of(dilation.begin(), dilation.end(), [](size_t s) { return s != 1; }))
    {
        return false;
    }
    // MKLDNN doesn't support negative padding. Padding is signed
    // (CoordinateDiff), so the predicate must use the container's own signed
    // value type: the previous `size_t` parameter made `s < 0` always false
    // and silently disabled this check.
    auto pad_above = convolution->get_padding_above();
    if (std::any_of(pad_above.begin(), pad_above.end(), [](auto s) { return s < 0; }))
    {
        return false;
    }
    auto pad_below = convolution->get_padding_below();
    if (std::any_of(pad_below.begin(), pad_below.end(), [](auto s) { return s < 0; }))
    {
        return false;
    }
    // Callback handles 1D/2D/3D convolutions only (rank = spatial dims + 2).
    if (arg0_rank != 3 && arg0_rank != 4 && arg0_rank != 5)
    {
        return false;
    }
    // Only support f32 for now
    if (node->get_input_element_type(0) != ngraph::element::f32 ||
        node->get_input_element_type(1) != ngraph::element::f32 ||
        node->get_output_element_type(0) != ngraph::element::f32)
    {
        return false;
    }
    return true;
}
bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node)
{
if (is_type<Parameter>(node) || is_type<Result>(node))
......@@ -474,6 +515,16 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
std::all_of(window_dilation.begin(), window_dilation.end(), is_one);
}
if (is_type<ngraph::op::ConvolutionBias>(node))
{
// ConvBias is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK"))
{
return false;
}
return can_use_mkldnn_conv_callback<ngraph::op::ConvolutionBias>(node.get());
}
// MKLDNN only supports softmax across single axis
if (auto softmax = as_type_ptr<ngraph::op::Softmax>(node))
{
......@@ -552,7 +603,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (is_type<ngraph::op::MatMul>(node))
{
// MatMul is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK"))
if (!getenv_bool("NGRAPH_MLIR_CALLBACK") || node->get_input_shape(0).size() != 2 ||
node->get_input_shape(1).size() != 2)
{
return false;
}
......@@ -561,7 +613,8 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
if (is_type<ngraph::op::Gemm>(node))
{
// Gemm is only supported through callback
if (!getenv_bool("NGRAPH_MLIR_CALLBACK"))
if (!getenv_bool("NGRAPH_MLIR_CALLBACK") || node->get_input_shape(0).size() != 2 ||
node->get_input_shape(1).size() != 2)
{
return false;
}
......
......@@ -478,6 +478,21 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::GroupConvo
return op;
}
// Lower ngraph::op::ConvolutionBias into an MLIR NGConvBiasOp, copying the
// convolution configuration (window strides, window dilation, padding, fused
// relu flag) from the nGraph node onto the op's attributes.
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::ConvolutionBias)
{
    mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGConvBiasOp>(ngNode);
    auto convNode = static_cast<const ngraph::op::ConvolutionBias*>(ngNode);
    auto convOp = llvm::cast<mlir::NGConvBiasOp>(op);
    // Each Shape/Strides vector becomes an I64ArrayAttr on the op.
    convOp.setStrides(NgDialectObj.getShapeAsAttr(convNode->get_window_movement_strides()));
    convOp.setDilation(NgDialectObj.getShapeAsAttr(convNode->get_window_dilation_strides()));
    convOp.setPadBelow(NgDialectObj.getShapeAsAttr(convNode->get_padding_below()));
    convOp.setPadAbove(NgDialectObj.getShapeAsAttr(convNode->get_padding_above()));
    convOp.setWithRelu(NgDialectObj.m_builder.getBoolAttr(convNode->with_relu()));
    return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::AvgPool)
{
......
......@@ -94,6 +94,16 @@ namespace ngraph
};
// These structs and union are used to pass attributes to callbacks.
// Convolution attributes handed from compiled MLIR code to the mkldnn
// callbacks. N is the number of spatial dimensions (1, 2 or 3).
template <int N>
struct convAttrs
{
    // When true the callback fuses a relu post-op into the convolution.
    bool withRelu;
    int64_t windowStrides[N];  // stride per spatial dimension
    int64_t windowDilation[N]; // nGraph convention: 1 == no dilation
    int64_t padBelow[N];       // padding at the start of each spatial dim
    int64_t padAbove[N];       // padding at the end of each spatial dim
};
template <int N>
struct poolAttrs
{
......@@ -119,8 +129,22 @@ namespace ngraph
BroadcastType broadcastHint;
};
// Discriminator describing which member of the opAttrs union is active; also
// selects the corresponding LLVM struct layout (see getLLVMType).
enum class AttrsType
{
    INT = 0, // plain integer attribute (intAttr)
    CONV1D,  // convAttrs<1>
    CONV2D,  // convAttrs<2>
    CONV3D,  // convAttrs<3>
    POOL2D,  // poolAttrs<2>
    POOL3D,  // poolAttrs<3>
    GEMM     // gemmAttrs
};
union opAttrs {
int intAttr;
int64_t intAttr;
convAttrs<1> convAttrs1d;
convAttrs<2> convAttrs2d;
convAttrs<3> convAttrs3d;
poolAttrs<2> poolAttrs2d;
poolAttrs<3> poolAttrs3d;
gemmAttrs gemmAttrs2d;
......
......@@ -28,13 +28,7 @@
using namespace ngraph;
using namespace ngraph::runtime::ngmlir;
extern std::vector<opAttrs> opAttrsVec;
static inline opAttrs getAttrs(size_t index)
{
return opAttrsVec[index];
}
static bool inline compare_mkldnn_dims(mkldnn_dims_t& arr1, mkldnn_dims_t& arr2, size_t size)
static bool inline compareMkldnnDims(mkldnn_dims_t& arr1, mkldnn_dims_t& arr2, size_t size)
{
for (auto i = 0; i < size; i++)
{
......@@ -46,8 +40,7 @@ static bool inline compare_mkldnn_dims(mkldnn_dims_t& arr1, mkldnn_dims_t& arr2,
return true;
}
static bool
compare_mkldnn_strides_order(mkldnn_dims_t& strides1, mkldnn_dims_t& strides2, size_t size)
static bool compareMkldnnStridesOrder(mkldnn_dims_t& strides1, mkldnn_dims_t& strides2, size_t size)
{
std::vector<size_t> indices1(size, 0), indices2(size, 0);
for (size_t i = 0; i < size; i++)
......@@ -72,8 +65,7 @@ static bool
return true;
}
static bool compare_mkldnn_md_formats(const mkldnn::memory::desc& lhs,
const mkldnn::memory::desc& rhs)
static bool compareMkldnnMdFormats(const mkldnn::memory::desc& lhs, const mkldnn::memory::desc& rhs)
{
mkldnn_memory_desc_t md1 = lhs.data, md2 = rhs.data;
......@@ -97,59 +89,58 @@ static bool compare_mkldnn_md_formats(const mkldnn::memory::desc& lhs,
auto blk2 = md2.format_desc.blocking;
if (blk1.inner_nblks != blk2.inner_nblks ||
!compare_mkldnn_dims(blk1.inner_blks, blk2.inner_blks, blk1.inner_nblks) ||
!compare_mkldnn_dims(blk1.inner_idxs, blk2.inner_idxs, blk1.inner_nblks))
!compareMkldnnDims(blk1.inner_blks, blk2.inner_blks, blk1.inner_nblks) ||
!compareMkldnnDims(blk1.inner_idxs, blk2.inner_idxs, blk1.inner_nblks))
{
return false;
}
return compare_mkldnn_strides_order(blk1.strides, blk2.strides, md1.ndims);
return compareMkldnnStridesOrder(blk1.strides, blk2.strides, md1.ndims);
}
static mkldnn::memory convert_layout_if_diff(const mkldnn::memory::desc& lhs,
static mkldnn::memory convertLayoutIfDiff(const mkldnn::memory::desc& lhs,
const mkldnn::memory::desc& rhs,
void* ptr,
mkldnn::engine cpu_engine)
mkldnn::engine cpuEngine)
{
if (!compare_mkldnn_md_formats(lhs, rhs))
{
mkldnn::memory reorder_in = {lhs, cpu_engine, ptr};
mkldnn::memory reorder_out = {rhs, cpu_engine};
mkldnn::reorder convert(reorder_in, reorder_out);
std::unordered_map<int, mkldnn::memory> exec_args = {{MKLDNN_ARG_SRC, reorder_in},
{MKLDNN_ARG_DST, reorder_out}};
mkldnn::stream s(cpu_engine);
if (!compareMkldnnMdFormats(lhs, rhs))
{
mkldnn::memory reorderIn = {lhs, cpuEngine, ptr};
mkldnn::memory reorderOut = {rhs, cpuEngine};
mkldnn::reorder convert(reorderIn, reorderOut);
std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_SRC, reorderIn},
{MKLDNN_ARG_DST, reorderOut}};
mkldnn::stream s(cpuEngine);
try
{
convert.execute(s, exec_args);
convert.execute(s, execArgs);
s.wait();
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not run mkdnn primitive " + std::string(e.message));
}
return reorder_out;
return reorderOut;
}
else
{
return mkldnn::memory{lhs, cpu_engine, ptr};
return mkldnn::memory{lhs, cpuEngine, ptr};
}
}
static void convert_output_layout(const mkldnn::memory::desc& lhs,
static void convertOutputLayout(mkldnn::memory& reorderIn,
const mkldnn::memory::desc& rhs,
void* ptr,
mkldnn::engine cpu_engine)
mkldnn::engine cpuEngine)
{
mkldnn::memory reorder_in = {rhs, cpu_engine};
mkldnn::memory reorder_out = {lhs, cpu_engine, ptr};
mkldnn::reorder convert(reorder_in, reorder_out);
std::unordered_map<int, mkldnn::memory> exec_args = {{MKLDNN_ARG_SRC, reorder_in},
{MKLDNN_ARG_DST, reorder_out}};
mkldnn::stream s(cpu_engine);
mkldnn::memory reorderOut = {rhs, cpuEngine, ptr};
mkldnn::reorder convert(reorderIn, reorderOut);
std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_SRC, reorderIn},
{MKLDNN_ARG_DST, reorderOut}};
mkldnn::stream s(cpuEngine);
try
{
convert.execute(s, exec_args);
convert.execute(s, execArgs);
s.wait();
}
catch (const mkldnn::error& e)
......@@ -158,12 +149,197 @@ static void convert_output_layout(const mkldnn::memory::desc& lhs,
}
}
/// Convolution algorithm used by the mkldnn callbacks: let mkldnn pick
/// automatically when NGRAPH_ENABLE_CPU_CONV_AUTO is defined, otherwise
/// always use direct convolution.
static mkldnn::algorithm getConvAlgo()
{
#ifdef NGRAPH_ENABLE_CPU_CONV_AUTO
    return mkldnn::algorithm::convolution_auto;
#else
    return mkldnn::algorithm::convolution_direct;
#endif
}
/// Callback for ConvBias
///
/// Executes an mkldnn convolution + bias (optionally fused with relu) on the
/// memrefs passed in from compiled MLIR code.
///   rank         - full tensor rank: 3, 4 or 5 for 1D/2D/3D convolution
///   memRefData   - input image tensor
///   memRefWeights- filter tensor
///   memRefBias   - 1-D bias tensor
///   memRefOutput - result tensor (written in the caller's layout)
///   attrsPtr     - opAttrs union; the convAttrs<N> member matching `rank`
///                  must be the active one
static void __mlir_mkldnn_convbias(size_t rank,
                                   StaticMemRef* memRefData,
                                   StaticMemRef* memRefWeights,
                                   StaticMemRef* memRefBias,
                                   StaticMemRef* memRefOutput,
                                   opAttrs* attrsPtr)
{
    // shapeAndStrides packs the shape in [0, rank) and strides in [rank, 2*rank).
    mkldnn::memory::dims dataDims(rank);
    mkldnn::memory::dims dataStrides(rank);
    mkldnn::memory::dims weightsDims(rank);
    mkldnn::memory::dims weightsStrides(rank);
    mkldnn::memory::dims biasDims(1);
    mkldnn::memory::dims biasStrides(1);
    mkldnn::memory::dims resultDims(rank);
    mkldnn::memory::dims resultStrides(rank);
    biasDims[0] = memRefBias->shapeAndStrides[0];
    biasStrides[0] = memRefBias->shapeAndStrides[1];
    for (auto i = 0; i < rank; i++)
    {
        dataDims[i] = memRefData->shapeAndStrides[i];
        dataStrides[i] = memRefData->shapeAndStrides[rank + i];
        weightsDims[i] = memRefWeights->shapeAndStrides[i];
        weightsStrides[i] = memRefWeights->shapeAndStrides[rank + i];
        resultDims[i] = memRefOutput->shapeAndStrides[i];
        resultStrides[i] = memRefOutput->shapeAndStrides[rank + i];
    }
    // build mkldnn primitive and execute
    mkldnn::algorithm alg = getConvAlgo();
    mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32;
    // FORMAT::any lets mkldnn choose the optimal layout; the *Origin descs
    // record the caller's actual (strided) layouts for conversion below.
    auto dataDesc = mkldnn::memory::desc(dataDims, dtype, mkldnn::memory::FORMAT::any);
    auto dataDescOrigin = mkldnn::memory::desc(dataDims, dtype, dataStrides);
    auto weightsDesc = mkldnn::memory::desc(weightsDims, dtype, mkldnn::memory::FORMAT::any);
    auto weightsDescOrigin = mkldnn::memory::desc(weightsDims, dtype, weightsStrides);
    auto biasDesc = mkldnn::memory::desc(biasDims, dtype, mkldnn::memory::FORMAT::any);
    auto resultDesc = mkldnn::memory::desc(resultDims, dtype, mkldnn::memory::FORMAT::any);
    auto resultDescOrigin = mkldnn::memory::desc(resultDims, dtype, resultStrides);

    mkldnn::primitive_attr attr;
    mkldnn::engine cpuEngine(mkldnn::engine::kind::cpu, 0);
    mkldnn::convolution_forward::primitive_desc convPd;
    // Relu post-op is prepared unconditionally but only attached to `attr`
    // when the node was fused with relu (convAttrs.withRelu).
    mkldnn::post_ops ops;
    const float opsScale = 1.f;
    const float opsAlpha = -0.f; // relu negative slope
    const float opsBeta = 0.f;
    ops.append_eltwise(opsScale, mkldnn::algorithm::eltwise_relu, opsAlpha, opsBeta);
    if (rank == 3)
    {
        // 1D convolution: one spatial dimension.
        auto convAttrs = (*attrsPtr).convAttrs1d;
        try
        {
            auto convDesc = mkldnn::convolution_forward::desc(
                mkldnn::prop_kind::forward_inference,
                alg,
                dataDesc,
                weightsDesc,
                biasDesc,
                resultDesc,
                mkldnn::memory::dims{convAttrs.windowStrides[0]},
                // mkldnn dilation is zero-based, nGraph's is one-based.
                mkldnn::memory::dims{convAttrs.windowDilation[0] - 1},
                mkldnn::memory::dims{convAttrs.padBelow[0]},
                mkldnn::memory::dims{convAttrs.padAbove[0]});
            if (convAttrs.withRelu)
            {
                attr.set_post_ops(ops);
            }
            convPd = mkldnn::convolution_forward::primitive_desc(convDesc, attr, cpuEngine);
        }
        catch (const mkldnn::error& e)
        {
            throw ngraph_error("Could not create mkldnn conv descriptor " + std::string(e.message));
        }
    }
    else if (rank == 4)
    {
        // 2D convolution: two spatial dimensions.
        auto convAttrs = (*attrsPtr).convAttrs2d;
        try
        {
            auto convDesc = mkldnn::convolution_forward::desc(
                mkldnn::prop_kind::forward_inference,
                alg,
                dataDesc,
                weightsDesc,
                biasDesc,
                resultDesc,
                mkldnn::memory::dims{convAttrs.windowStrides[0], convAttrs.windowStrides[1]},
                // mkldnn dilation is zero-based, nGraph's is one-based.
                mkldnn::memory::dims{convAttrs.windowDilation[0] - 1,
                                     convAttrs.windowDilation[1] - 1},
                mkldnn::memory::dims{convAttrs.padBelow[0], convAttrs.padBelow[1]},
                mkldnn::memory::dims{convAttrs.padAbove[0], convAttrs.padAbove[1]});
            if (convAttrs.withRelu)
            {
                attr.set_post_ops(ops);
            }
            convPd = mkldnn::convolution_forward::primitive_desc(convDesc, attr, cpuEngine);
        }
        catch (const mkldnn::error& e)
        {
            throw ngraph_error("Could not create mkldnn conv descriptor " + std::string(e.message));
        }
    }
    else if (rank == 5)
    {
        // 3D convolution: three spatial dimensions.
        auto convAttrs = (*attrsPtr).convAttrs3d;
        try
        {
            auto convDesc = mkldnn::convolution_forward::desc(
                mkldnn::prop_kind::forward_inference,
                alg,
                dataDesc,
                weightsDesc,
                biasDesc,
                resultDesc,
                mkldnn::memory::dims{convAttrs.windowStrides[0],
                                     convAttrs.windowStrides[1],
                                     convAttrs.windowStrides[2]},
                // mkldnn dilation is zero-based, nGraph's is one-based.
                mkldnn::memory::dims{convAttrs.windowDilation[0] - 1,
                                     convAttrs.windowDilation[1] - 1,
                                     convAttrs.windowDilation[2] - 1},
                mkldnn::memory::dims{
                    convAttrs.padBelow[0], convAttrs.padBelow[1], convAttrs.padBelow[2]},
                mkldnn::memory::dims{
                    convAttrs.padAbove[0], convAttrs.padAbove[1], convAttrs.padAbove[2]});
            if (convAttrs.withRelu)
            {
                attr.set_post_ops(ops);
            }
            convPd = mkldnn::convolution_forward::primitive_desc(convDesc, attr, cpuEngine);
        }
        catch (const mkldnn::error& e)
        {
            throw ngraph_error("Could not create mkldnn conv descriptor " + std::string(e.message));
        }
    }
    mkldnn::convolution_forward conv(convPd);
    // Reorder inputs into the layouts the primitive chose, if they differ
    // from the caller's strided layouts.
    mkldnn::memory data =
        convertLayoutIfDiff(dataDescOrigin, convPd.src_desc(), memRefData->allocatedPtr, cpuEngine);
    mkldnn::memory weights = convertLayoutIfDiff(
        weightsDescOrigin, convPd.weights_desc(), memRefWeights->allocatedPtr, cpuEngine);
    mkldnn::memory bias{convPd.bias_desc(), cpuEngine, memRefBias->allocatedPtr};
    mkldnn::memory out;
    // If the primitive's destination layout differs from the caller's,
    // compute into a scratch buffer and reorder back after execution.
    bool needConvert = false;
    if (!compareMkldnnMdFormats(resultDescOrigin, convPd.dst_desc()))
    {
        out = mkldnn::memory(convPd.dst_desc(), cpuEngine);
        needConvert = true;
    }
    else
    {
        out = mkldnn::memory(convPd.dst_desc(), cpuEngine, memRefOutput->allocatedPtr);
    }

    std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_SRC, data},
                                                        {MKLDNN_ARG_WEIGHTS, weights},
                                                        {MKLDNN_ARG_BIAS, bias},
                                                        {MKLDNN_ARG_DST, out}};

    mkldnn::stream s(cpuEngine);
    try
    {
        conv.execute(s, execArgs);
        s.wait();
    }
    catch (const mkldnn::error& e)
    {
        throw ngraph_error("Could not run mkdnn primitive " + std::string(e.message));
    }
    if (needConvert)
    {
        // Reorder the scratch result into the caller's output layout.
        convertOutputLayout(out, resultDescOrigin, memRefOutput->allocatedPtr, cpuEngine);
    }
}
/// Callback for MaxPoolBackprop
static void __mlir_mkldnn_maxpoolbackprop(size_t rank,
StaticMemRef* memRefSrc,
StaticMemRef* memRefDelta,
StaticMemRef* memRefOutput,
size_t index)
opAttrs* attrsPtr)
{
mkldnn::memory::dims srcDims(rank);
mkldnn::memory::dims srcStrides(rank);
......@@ -182,107 +358,120 @@ static void __mlir_mkldnn_maxpoolbackprop(size_t rank,
}
// build mkldnn primitive and execute
auto required_format = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
auto requiredFormat = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32;
auto diff_dst_desc = mkldnn::memory::desc(deltaDims, dtype, required_format);
auto diff_src_desc = mkldnn::memory::desc(outDims, dtype, required_format);
auto src_desc_origin = mkldnn::memory::desc(srcDims, dtype, srcStrides);
auto diff_dst_desc_origin = mkldnn::memory::desc(deltaDims, dtype, deltaStrides);
auto diff_src_desc_origin = mkldnn::memory::desc(outDims, dtype, outStrides);
auto diffDstDesc = mkldnn::memory::desc(deltaDims, dtype, requiredFormat);
auto diffSrcDesc = mkldnn::memory::desc(outDims, dtype, requiredFormat);
auto srcDescOrigin = mkldnn::memory::desc(srcDims, dtype, srcStrides);
auto diffDstDescOrigin = mkldnn::memory::desc(deltaDims, dtype, deltaStrides);
auto diffSrcDescOrigin = mkldnn::memory::desc(outDims, dtype, outStrides);
mkldnn::primitive_attr attr;
mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_forward::primitive_desc maxpool_pd_f;
mkldnn::pooling_backward::primitive_desc maxpool_pd_b;
mkldnn::engine cpuEngine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_forward::primitive_desc maxpoolPdF;
mkldnn::pooling_backward::primitive_desc maxpoolPdB;
if (rank == 4)
{
poolAttrs<2> pAttrs = getAttrs(index).poolAttrs2d;
auto maxpool_desc_f = mkldnn::pooling_forward::desc(
poolAttrs<2> pAttrs = (*attrsPtr).poolAttrs2d;
try
{
auto maxpoolDescF = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::pooling_max,
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{pAttrs.windowStrides[0], pAttrs.windowStrides[1]},
mkldnn::memory::dims{pAttrs.windowShape[0], pAttrs.windowShape[1]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1]});
auto maxpool_desc_b = mkldnn::pooling_backward::desc(
auto maxpoolDescB = mkldnn::pooling_backward::desc(
mkldnn::algorithm::pooling_max,
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{pAttrs.windowStrides[0], pAttrs.windowStrides[1]},
mkldnn::memory::dims{pAttrs.windowShape[0], pAttrs.windowShape[1]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1]});
maxpool_pd_f = mkldnn::pooling_forward::primitive_desc(maxpool_desc_f, attr, cpu_engine);
maxpool_pd_b = mkldnn::pooling_backward::primitive_desc(
maxpool_desc_b, attr, cpu_engine, maxpool_pd_f);
maxpoolPdF = mkldnn::pooling_forward::primitive_desc(maxpoolDescF, attr, cpuEngine);
maxpoolPdB =
mkldnn::pooling_backward::primitive_desc(maxpoolDescB, attr, cpuEngine, maxpoolPdF);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn max pooling descriptor " +
std::string(e.message));
}
}
else if (rank == 5)
{
poolAttrs<3> pAttrs = getAttrs(index).poolAttrs3d;
auto maxpool_desc_f = mkldnn::pooling_forward::desc(
poolAttrs<3> pAttrs = (*attrsPtr).poolAttrs3d;
try
{
auto maxpoolDescF = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_training,
mkldnn::algorithm::pooling_max,
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{
pAttrs.windowStrides[0], pAttrs.windowStrides[1], pAttrs.windowStrides[2]},
mkldnn::memory::dims{
pAttrs.windowShape[0], pAttrs.windowShape[1], pAttrs.windowShape[2]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1], pAttrs.padBelow[2]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1], pAttrs.padAbove[2]});
auto maxpool_desc_b = mkldnn::pooling_backward::desc(
auto maxpoolDescB = mkldnn::pooling_backward::desc(
mkldnn::algorithm::pooling_max,
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{
pAttrs.windowStrides[0], pAttrs.windowStrides[1], pAttrs.windowStrides[2]},
mkldnn::memory::dims{
pAttrs.windowShape[0], pAttrs.windowShape[1], pAttrs.windowShape[2]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1], pAttrs.padBelow[2]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1], pAttrs.padAbove[2]});
auto maxpool_pd_f =
mkldnn::pooling_forward::primitive_desc(maxpool_desc_f, attr, cpu_engine);
maxpool_pd_f = mkldnn::pooling_forward::primitive_desc(maxpool_desc_f, attr, cpu_engine);
maxpool_pd_b = mkldnn::pooling_backward::primitive_desc(
maxpool_desc_b, attr, cpu_engine, maxpool_pd_f);
maxpoolPdF = mkldnn::pooling_forward::primitive_desc(maxpoolDescF, attr, cpuEngine);
maxpoolPdB =
mkldnn::pooling_backward::primitive_desc(maxpoolDescB, attr, cpuEngine, maxpoolPdF);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn max pooling descriptor " +
std::string(e.message));
}
}
mkldnn::pooling_forward maxpool_f(maxpool_pd_f);
mkldnn::memory src_mem = convert_layout_if_diff(
src_desc_origin, maxpool_pd_b.diff_src_desc(), memRefSrc->allocatedPtr, cpu_engine);
mkldnn::memory dst_mem{maxpool_pd_b.diff_dst_desc(), cpu_engine};
mkldnn::memory workspace{maxpool_pd_f.workspace_desc(), cpu_engine};
mkldnn::pooling_forward maxpoolF(maxpoolPdF);
mkldnn::memory srcMem = convertLayoutIfDiff(
srcDescOrigin, maxpoolPdB.diff_src_desc(), memRefSrc->allocatedPtr, cpuEngine);
mkldnn::memory dstMem{maxpoolPdB.diff_dst_desc(), cpuEngine};
mkldnn::memory workspace{maxpoolPdF.workspace_desc(), cpuEngine};
mkldnn::pooling_backward maxpool_b(maxpool_pd_b);
mkldnn::memory diff_dst = convert_layout_if_diff(
diff_dst_desc_origin, maxpool_pd_b.diff_dst_desc(), memRefDelta->allocatedPtr, cpu_engine);
mkldnn::memory diff_src;
bool need_convert = false;
if (!compare_mkldnn_md_formats(diff_src_desc_origin, maxpool_pd_b.diff_src_desc()))
mkldnn::pooling_backward maxpoolB(maxpoolPdB);
mkldnn::memory diffDst = convertLayoutIfDiff(
diffDstDescOrigin, maxpoolPdB.diff_dst_desc(), memRefDelta->allocatedPtr, cpuEngine);
mkldnn::memory diffSrc;
bool needConvert = false;
if (!compareMkldnnMdFormats(diffSrcDescOrigin, maxpoolPdB.diff_src_desc()))
{
diff_src = mkldnn::memory(maxpool_pd_b.diff_src_desc(), cpu_engine);
need_convert = true;
diffSrc = mkldnn::memory(maxpoolPdB.diff_src_desc(), cpuEngine);
needConvert = true;
}
else
{
diff_src =
mkldnn::memory(maxpool_pd_b.diff_src_desc(), cpu_engine, memRefOutput->allocatedPtr);
diffSrc = mkldnn::memory(maxpoolPdB.diff_src_desc(), cpuEngine, memRefOutput->allocatedPtr);
}
std::unordered_map<int, mkldnn::memory> exec_args_f = {
{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_WORKSPACE, workspace}, {MKLDNN_ARG_DST, dst_mem}};
std::unordered_map<int, mkldnn::memory> exec_args_b = {{MKLDNN_ARG_DIFF_DST, diff_dst},
std::unordered_map<int, mkldnn::memory> execArgsF = {
{MKLDNN_ARG_SRC, srcMem}, {MKLDNN_ARG_WORKSPACE, workspace}, {MKLDNN_ARG_DST, dstMem}};
std::unordered_map<int, mkldnn::memory> execArgsB = {{MKLDNN_ARG_DIFF_DST, diffDst},
{MKLDNN_ARG_WORKSPACE, workspace},
{MKLDNN_ARG_DIFF_SRC, diff_src}};
{MKLDNN_ARG_DIFF_SRC, diffSrc}};
mkldnn::stream s(cpu_engine);
mkldnn::stream s(cpuEngine);
try
{
maxpool_f.execute(s, exec_args_f);
maxpoolF.execute(s, execArgsF);
s.wait();
maxpool_b.execute(s, exec_args_b);
maxpoolB.execute(s, execArgsB);
s.wait();
}
catch (const mkldnn::error& e)
......@@ -290,12 +479,9 @@ static void __mlir_mkldnn_maxpoolbackprop(size_t rank,
throw ngraph_error("Could not run mkdnn primitive " + std::string(e.message));
}
if (need_convert)
if (needConvert)
{
convert_output_layout(diff_dst_desc_origin,
maxpool_pd_b.diff_dst_desc(),
memRefOutput->allocatedPtr,
cpu_engine);
convertOutputLayout(diffSrc, diffSrcDescOrigin, memRefOutput->allocatedPtr, cpuEngine);
}
}
......@@ -303,7 +489,7 @@ static void __mlir_mkldnn_maxpoolbackprop(size_t rank,
static void __mlir_mkldnn_avgpoolbackprop(size_t rank,
StaticMemRef* memRefInput,
StaticMemRef* memRefOutput,
size_t index)
opAttrs* attrsPtr)
{
mkldnn::memory::dims dims(rank);
mkldnn::memory::dims strides(rank);
......@@ -318,99 +504,115 @@ static void __mlir_mkldnn_avgpoolbackprop(size_t rank,
}
// build mkldnn primitive and execute
auto required_format = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
auto requiredFormat = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32;
auto diff_dst_desc = mkldnn::memory::desc(dims, dtype, required_format);
auto diff_src_desc = mkldnn::memory::desc(outDims, dtype, required_format);
auto diff_dst_desc_origin = mkldnn::memory::desc(dims, dtype, strides);
auto diff_src_desc_origin = mkldnn::memory::desc(outDims, dtype, outStrides);
auto diffDstDesc = mkldnn::memory::desc(dims, dtype, requiredFormat);
auto diffSrcDesc = mkldnn::memory::desc(outDims, dtype, requiredFormat);
auto diffDstDescOrigin = mkldnn::memory::desc(dims, dtype, strides);
auto diffSrcDescOrigin = mkldnn::memory::desc(outDims, dtype, outStrides);
mkldnn::primitive_attr attr;
mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_backward::primitive_desc avgpool_pd_b;
mkldnn::engine cpuEngine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_backward::primitive_desc avgpoolPdB;
if (rank == 4)
{
poolAttrs<2> pAttrs = getAttrs(index).poolAttrs2d;
auto avgpool_desc_f = mkldnn::pooling_forward::desc(
poolAttrs<2> pAttrs = (*attrsPtr).poolAttrs2d;
try
{
auto avgpoolDescF = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_training,
(pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding),
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{pAttrs.windowStrides[0], pAttrs.windowStrides[1]},
mkldnn::memory::dims{pAttrs.windowShape[0], pAttrs.windowShape[1]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1]});
auto avgpool_desc_b = mkldnn::pooling_backward::desc(
auto avgpoolDescB = mkldnn::pooling_backward::desc(
(pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding),
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{pAttrs.windowStrides[0], pAttrs.windowStrides[1]},
mkldnn::memory::dims{pAttrs.windowShape[0], pAttrs.windowShape[1]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1]});
auto avgpool_pd_f =
mkldnn::pooling_forward::primitive_desc(avgpool_desc_f, attr, cpu_engine);
avgpool_pd_b = mkldnn::pooling_backward::primitive_desc(
avgpool_desc_b, attr, cpu_engine, avgpool_pd_f);
auto avgpoolPdF =
mkldnn::pooling_forward::primitive_desc(avgpoolDescF, attr, cpuEngine);
avgpoolPdB =
mkldnn::pooling_backward::primitive_desc(avgpoolDescB, attr, cpuEngine, avgpoolPdF);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn avg pooling descriptor " +
std::string(e.message));
}
}
else if (rank == 5)
{
poolAttrs<3> pAttrs = getAttrs(index).poolAttrs3d;
auto avgpool_desc_f = mkldnn::pooling_forward::desc(
poolAttrs<3> pAttrs = (*attrsPtr).poolAttrs3d;
try
{
auto avgpoolDescF = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_training,
(pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding),
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{
pAttrs.windowStrides[0], pAttrs.windowStrides[1], pAttrs.windowStrides[2]},
mkldnn::memory::dims{
pAttrs.windowShape[0], pAttrs.windowShape[1], pAttrs.windowShape[2]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1], pAttrs.padBelow[2]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1], pAttrs.padAbove[2]});
auto avgpool_desc_b = mkldnn::pooling_backward::desc(
auto avgpoolDescB = mkldnn::pooling_backward::desc(
(pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding),
diff_src_desc,
diff_dst_desc,
diffSrcDesc,
diffDstDesc,
mkldnn::memory::dims{
pAttrs.windowStrides[0], pAttrs.windowStrides[1], pAttrs.windowStrides[2]},
mkldnn::memory::dims{
pAttrs.windowShape[0], pAttrs.windowShape[1], pAttrs.windowShape[2]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1], pAttrs.padBelow[2]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1], pAttrs.padAbove[2]});
auto avgpool_pd_f =
mkldnn::pooling_forward::primitive_desc(avgpool_desc_f, attr, cpu_engine);
avgpool_pd_b = mkldnn::pooling_backward::primitive_desc(
avgpool_desc_b, attr, cpu_engine, avgpool_pd_f);
auto avgpoolPdF =
mkldnn::pooling_forward::primitive_desc(avgpoolDescF, attr, cpuEngine);
avgpoolPdB =
mkldnn::pooling_backward::primitive_desc(avgpoolDescB, attr, cpuEngine, avgpoolPdF);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn avg pooling descriptor " +
std::string(e.message));
}
}
mkldnn::pooling_backward avgpool(avgpool_pd_b);
mkldnn::memory in = convert_layout_if_diff(
diff_dst_desc_origin, avgpool_pd_b.diff_dst_desc(), memRefInput->allocatedPtr, cpu_engine);
mkldnn::pooling_backward avgpool(avgpoolPdB);
mkldnn::memory in = convertLayoutIfDiff(
diffDstDescOrigin, avgpoolPdB.diff_dst_desc(), memRefInput->allocatedPtr, cpuEngine);
mkldnn::memory out;
bool need_convert = false;
if (!compare_mkldnn_md_formats(diff_src_desc_origin, avgpool_pd_b.diff_src_desc()))
bool needConvert = false;
if (!compareMkldnnMdFormats(diffSrcDescOrigin, avgpoolPdB.diff_src_desc()))
{
out = mkldnn::memory(avgpool_pd_b.diff_src_desc(), cpu_engine);
need_convert = true;
out = mkldnn::memory(avgpoolPdB.diff_src_desc(), cpuEngine);
needConvert = true;
}
else
{
out = mkldnn::memory(avgpool_pd_b.diff_src_desc(), cpu_engine, memRefOutput->allocatedPtr);
out = mkldnn::memory(avgpoolPdB.diff_src_desc(), cpuEngine, memRefOutput->allocatedPtr);
}
std::unordered_map<int, mkldnn::memory> exec_args = {{MKLDNN_ARG_DIFF_DST, in},
std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_DIFF_DST, in},
{MKLDNN_ARG_DIFF_SRC, out}};
mkldnn::stream s(cpu_engine);
mkldnn::stream s(cpuEngine);
try
{
avgpool.execute(s, exec_args);
avgpool.execute(s, execArgs);
s.wait();
}
catch (const mkldnn::error& e)
......@@ -418,18 +620,18 @@ static void __mlir_mkldnn_avgpoolbackprop(size_t rank,
throw ngraph_error("Could not run mkdnn primitive " + std::string(e.message));
}
if (need_convert)
if (needConvert)
{
convert_output_layout(diff_dst_desc_origin,
avgpool_pd_b.diff_dst_desc(),
memRefOutput->allocatedPtr,
cpu_engine);
convertOutputLayout(out, diffSrcDescOrigin, memRefOutput->allocatedPtr, cpuEngine);
}
}
/// Callback for AvgPool and MaxPool
static void __mlir_mkldnn_pooling(
size_t rank, StaticMemRef* memRefInput, StaticMemRef* memRefOutput, size_t index, OpType type)
static void __mlir_mkldnn_pooling(size_t rank,
StaticMemRef* memRefInput,
StaticMemRef* memRefOutput,
opAttrs* attrsPtr,
OpType type)
{
mkldnn::memory::dims dims(rank);
mkldnn::memory::dims strides(rank);
......@@ -444,77 +646,93 @@ static void __mlir_mkldnn_pooling(
}
// build mkldnn primitive and execute
auto required_format = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
auto requiredFormat = rank == 4 ? mkldnn::memory::FORMAT::nchw : mkldnn::memory::FORMAT::ncdhw;
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32;
auto input_desc = mkldnn::memory::desc(dims, dtype, required_format);
auto result_desc = mkldnn::memory::desc(outDims, dtype, required_format);
auto input_desc_origin = mkldnn::memory::desc(dims, dtype, strides);
auto result_desc_origin = mkldnn::memory::desc(outDims, dtype, outStrides);
auto inputDesc = mkldnn::memory::desc(dims, dtype, requiredFormat);
auto resultDesc = mkldnn::memory::desc(outDims, dtype, requiredFormat);
auto inputDescOrigin = mkldnn::memory::desc(dims, dtype, strides);
auto resultDescOrigin = mkldnn::memory::desc(outDims, dtype, outStrides);
mkldnn::primitive_attr attr;
mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_forward::primitive_desc pool_pd;
mkldnn::engine cpuEngine(mkldnn::engine::kind::cpu, 0);
mkldnn::pooling_forward::primitive_desc poolPd;
if (rank == 4)
{
poolAttrs<2> pAttrs = getAttrs(index).poolAttrs2d;
poolAttrs<2> pAttrs = (*attrsPtr).poolAttrs2d;
mkldnn::algorithm alg = type == OpType::MAXPOOL
? mkldnn::algorithm::pooling_max
: (pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding);
auto pool_desc = mkldnn::pooling_forward::desc(
try
{
auto poolDesc = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_inference,
alg,
input_desc,
result_desc,
inputDesc,
resultDesc,
mkldnn::memory::dims{pAttrs.windowStrides[0], pAttrs.windowStrides[1]},
mkldnn::memory::dims{pAttrs.windowShape[0], pAttrs.windowShape[1]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1]});
pool_pd = mkldnn::pooling_forward::primitive_desc(pool_desc, attr, cpu_engine);
poolPd = mkldnn::pooling_forward::primitive_desc(poolDesc, attr, cpuEngine);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn pooling descriptor " +
std::string(e.message));
}
}
else if (rank == 5)
{
poolAttrs<3> pAttrs = getAttrs(index).poolAttrs3d;
poolAttrs<3> pAttrs = (*attrsPtr).poolAttrs3d;
mkldnn::algorithm alg = type == OpType::MAXPOOL
? mkldnn::algorithm::pooling_max
: (pAttrs.includePaddingInAvgComputation
? mkldnn::algorithm::pooling_avg_include_padding
: mkldnn::algorithm::pooling_avg_exclude_padding);
auto pool_desc = mkldnn::pooling_forward::desc(
try
{
auto poolDesc = mkldnn::pooling_forward::desc(
mkldnn::prop_kind::forward_inference,
alg,
input_desc,
result_desc,
inputDesc,
resultDesc,
mkldnn::memory::dims{
pAttrs.windowStrides[0], pAttrs.windowStrides[1], pAttrs.windowStrides[2]},
mkldnn::memory::dims{
pAttrs.windowShape[0], pAttrs.windowShape[1], pAttrs.windowShape[2]},
mkldnn::memory::dims{pAttrs.padBelow[0], pAttrs.padBelow[1], pAttrs.padBelow[2]},
mkldnn::memory::dims{pAttrs.padAbove[0], pAttrs.padAbove[1], pAttrs.padAbove[2]});
pool_pd = mkldnn::pooling_forward::primitive_desc(pool_desc, attr, cpu_engine);
poolPd = mkldnn::pooling_forward::primitive_desc(poolDesc, attr, cpuEngine);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn pooing descriptor " +
std::string(e.message));
}
}
mkldnn::pooling_forward pool(pool_pd);
mkldnn::memory in = convert_layout_if_diff(
input_desc_origin, pool_pd.src_desc(), memRefInput->allocatedPtr, cpu_engine);
mkldnn::pooling_forward pool(poolPd);
mkldnn::memory in = convertLayoutIfDiff(
inputDescOrigin, poolPd.src_desc(), memRefInput->allocatedPtr, cpuEngine);
mkldnn::memory out;
bool need_convert = false;
if (!compare_mkldnn_md_formats(result_desc_origin, pool_pd.dst_desc()))
bool needConvert = false;
if (!compareMkldnnMdFormats(resultDescOrigin, poolPd.dst_desc()))
{
out = mkldnn::memory(pool_pd.dst_desc(), cpu_engine);
need_convert = true;
out = mkldnn::memory(poolPd.dst_desc(), cpuEngine);
needConvert = true;
}
else
{
out = mkldnn::memory(pool_pd.dst_desc(), cpu_engine, memRefOutput->allocatedPtr);
out = mkldnn::memory(poolPd.dst_desc(), cpuEngine, memRefOutput->allocatedPtr);
}
std::unordered_map<int, mkldnn::memory> exec_args = {{MKLDNN_ARG_SRC, in},
std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_SRC, in},
{MKLDNN_ARG_DST, out}};
mkldnn::stream s(cpu_engine);
mkldnn::stream s(cpuEngine);
try
{
pool.execute(s, exec_args);
pool.execute(s, execArgs);
s.wait();
}
catch (const mkldnn::error& e)
......@@ -522,10 +740,9 @@ static void __mlir_mkldnn_pooling(
throw ngraph_error("Could not run mkdnn primitive " + std::string(e.message));
}
if (need_convert)
if (needConvert)
{
convert_output_layout(
result_desc_origin, pool_pd.dst_desc(), memRefOutput->allocatedPtr, cpu_engine);
convertOutputLayout(out, resultDescOrigin, memRefOutput->allocatedPtr, cpuEngine);
}
}
......@@ -533,7 +750,7 @@ static void __mlir_mkldnn_pooling(
static void __mlir_mkldnn_softmax(size_t rank,
StaticMemRef* memRefInput,
StaticMemRef* memRefOutput,
int index)
opAttrs* attrsPtr)
{
mkldnn::memory::dims dims(rank);
mkldnn::memory::dims strides(rank);
......@@ -542,28 +759,36 @@ static void __mlir_mkldnn_softmax(size_t rank,
dims[i] = memRefInput->shapeAndStrides[i];
strides[i] = memRefInput->shapeAndStrides[rank + i];
}
auto softmax_axis = getAttrs(index).intAttr;
auto softmaxAxis = (*attrsPtr).intAttr;
// build mkldnn primitive and execute
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32;
auto input_desc = mkldnn::memory::desc(dims, dtype, strides);
auto softmax_desc =
mkldnn::softmax_forward::desc(mkldnn::prop_kind::forward_scoring, input_desc, softmax_axis);
auto inputDesc = mkldnn::memory::desc(dims, dtype, strides);
mkldnn::softmax_forward::primitive_desc softmaxPd;
mkldnn::engine cpuEngine(mkldnn::engine::kind::cpu, 0);
try
{
auto softmaxDesc = mkldnn::softmax_forward::desc(
mkldnn::prop_kind::forward_scoring, inputDesc, softmaxAxis);
mkldnn::primitive_attr attr;
mkldnn::engine cpu_engine(mkldnn::engine::kind::cpu, 0);
auto softmax_pd = mkldnn::softmax_forward::primitive_desc(softmax_desc, attr, cpu_engine);
mkldnn::softmax_forward softmax(softmax_pd);
softmaxPd = mkldnn::softmax_forward::primitive_desc(softmaxDesc, attr, cpuEngine);
}
catch (const mkldnn::error& e)
{
throw ngraph_error("Could not create mkldnn softmax descriptor " + std::string(e.message));
}
mkldnn::softmax_forward softmax(softmaxPd);
mkldnn::memory in{softmax_pd.src_desc(), cpu_engine, memRefInput->allocatedPtr};
mkldnn::memory out{softmax_pd.dst_desc(), cpu_engine, memRefOutput->allocatedPtr};
mkldnn::memory in{softmaxPd.src_desc(), cpuEngine, memRefInput->allocatedPtr};
mkldnn::memory out{softmaxPd.dst_desc(), cpuEngine, memRefOutput->allocatedPtr};
std::unordered_map<int, mkldnn::memory> exec_args = {{MKLDNN_ARG_SRC, in},
std::unordered_map<int, mkldnn::memory> execArgs = {{MKLDNN_ARG_SRC, in},
{MKLDNN_ARG_DST, out}};
mkldnn::stream s(cpu_engine);
mkldnn::stream s(cpuEngine);
try
{
softmax.execute(s, exec_args);
softmax.execute(s, execArgs);
s.wait();
}
catch (const mkldnn::error& e)
......@@ -576,9 +801,9 @@ static void __mlir_mkldnn_softmax(size_t rank,
static void __mlir_cblas_sgemm(StaticMemRef* memRefmatA,
StaticMemRef* memRefmatB,
StaticMemRef* memRefmatC,
size_t index)
opAttrs* attrsPtr)
{
gemmAttrs gAttrs = getAttrs(index).gemmAttrs2d;
gemmAttrs gAttrs = (*attrsPtr).gemmAttrs2d;
;
cblas::cblas_sgemm(cblas::Layout::RowMajor,
gAttrs.transposeA ? cblas::Transpose::Transpose : cblas::Transpose::None,
......@@ -601,9 +826,9 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
StaticMemRef* memRefmatB,
StaticMemRef* memRefmatC,
StaticMemRef* memRefmatOut,
size_t index)
opAttrs* attrsPtr)
{
gemmAttrs gAttrs = getAttrs(index).gemmAttrs2d;
gemmAttrs gAttrs = (*attrsPtr).gemmAttrs2d;
auto transposeA = gAttrs.transposeA;
auto transposeB = gAttrs.transposeB;
auto m = gAttrs.m;
......@@ -719,7 +944,8 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
}
}
extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t index, OpType type)
extern "C" void
_mlir_ciface_callback_1_input(void* input, void* output, void* attrsPtr, OpType type)
{
auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input);
auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
......@@ -729,14 +955,14 @@ extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t
__mlir_mkldnn_softmax(unrankedMemRefInput->rank,
unrankedMemRefInput->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index);
static_cast<opAttrs*>(attrsPtr));
}
else if (type == OpType::AVGPOOL || type == OpType::MAXPOOL)
{
__mlir_mkldnn_pooling(unrankedMemRefInput->rank,
unrankedMemRefInput->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index,
static_cast<opAttrs*>(attrsPtr),
type);
}
else if (type == OpType::AVGPOOLBACKPROP)
......@@ -744,7 +970,7 @@ extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t
__mlir_mkldnn_avgpoolbackprop(unrankedMemRefInput->rank,
unrankedMemRefInput->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index);
static_cast<opAttrs*>(attrsPtr));
}
else
{
......@@ -753,26 +979,26 @@ extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t
}
extern "C" void _mlir_ciface_callback_2_inputs(
void* input0, void* input1, void* output, size_t index, OpType type)
void* input0, void* input1, void* output, void* attrsPtr, OpType type)
{
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
if (type == OpType::MAXPOOLBACKPROP)
if (type == OpType::MATMUL)
{
__mlir_mkldnn_maxpoolbackprop(unrankedMemRefInput0->rank,
unrankedMemRefInput0->memRefDescPtr,
__mlir_cblas_sgemm(unrankedMemRefInput0->memRefDescPtr,
unrankedMemRefInput1->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index);
static_cast<opAttrs*>(attrsPtr));
}
else if (type == OpType::MATMUL)
else if (type == OpType::MAXPOOLBACKPROP)
{
__mlir_cblas_sgemm(unrankedMemRefInput0->memRefDescPtr,
__mlir_mkldnn_maxpoolbackprop(unrankedMemRefInput0->rank,
unrankedMemRefInput0->memRefDescPtr,
unrankedMemRefInput1->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index);
static_cast<opAttrs*>(attrsPtr));
}
else
{
......@@ -781,7 +1007,7 @@ extern "C" void _mlir_ciface_callback_2_inputs(
}
extern "C" void _mlir_ciface_callback_3_inputs(
void* input0, void* input1, void* input2, void* output, size_t index, OpType type)
void* input0, void* input1, void* input2, void* output, void* attrsPtr, OpType type)
{
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
......@@ -794,7 +1020,16 @@ extern "C" void _mlir_ciface_callback_3_inputs(
unrankedMemRefInput1->memRefDescPtr,
unrankedMemRefInput2->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
index);
static_cast<opAttrs*>(attrsPtr));
}
else if (type == OpType::CONVOLUTIONBIAS)
{
__mlir_mkldnn_convbias(unrankedMemRefInput0->rank,
unrankedMemRefInput0->memRefDescPtr,
unrankedMemRefInput1->memRefDescPtr,
unrankedMemRefInput2->memRefDescPtr,
unrankedMemRefOutput->memRefDescPtr,
static_cast<opAttrs*>(attrsPtr));
}
else
{
......
......@@ -60,27 +60,29 @@ llvm::cl::opt<bool> clEnableBarePtrMemRefLowering(
llvm::cl::init(false),
llvm::cl::desc("Enable the lowering of MemRefs to LLVM bare pointers"));
void MLIRCPURuntime::run(const std::vector<MemRefArg>& args)
void MLIRCPURuntime::run(const std::vector<MemRefArg>& args, bool firstIteration)
{
// run_internal(*reinterpret_cast<std::vector<void*>*>(args), shapeVec, stridesVec);
run_internal(args);
run_internal(args, firstIteration);
}
void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args)
void MLIRCPURuntime::run_internal(const std::vector<MemRefArg>& args, bool firstIteration)
{
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer created with
// the default or user-provided optimization level.
if (!m_engine)
{
auto llvmTransformer = mlir::makeOptimizingTransformer(
MLIRCPUBackend::mlirOptLevel, /*sizeLevel=*/0, MLIRCPUBackend::targetMachine.get());
auto maybeEngine = mlir::ExecutionEngine::create(
m_module.get(), llvmTransformer, MLIRCPUBackend::mlirOptLevel);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
}
bindArguments(args);
execute();
execute(firstIteration);
cleanup();
}
......@@ -90,7 +92,8 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
auto name = clEnableBarePtrMemRefLowering ? "main" : "_mlir_ciface_main";
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>(name);
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments
......@@ -138,21 +141,46 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
}
// Lowers standard dialect to LLVM dialect and uses the MLIR execution engine to execute the code.
void MLIRCPURuntime::execute()
void MLIRCPURuntime::execute(bool firstIteration)
{
// Invoke the JIT-compiled function with the arguments. Note that, for API
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
if (!clEnableBarePtrMemRefLowering)
{
if (firstIteration)
{
auto invocationResult = m_engine->invoke("_mlir_ciface_callback_init");
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult,
"JIT invocation of '_mlir_ciface_callback_init' failed\n");
}
auto invocationResult =
m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
}
else
{
auto invocationResult =
m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
}
}
void MLIRCPURuntime::cleanup()
......
......@@ -55,14 +55,14 @@ namespace ngraph
{
public:
/// Executes a pre-compiled subgraph
void run(const std::vector<MemRefArg>& args) override;
void run(const std::vector<MemRefArg>& args, bool firstIteration) override;
private:
void run_internal(const std::vector<MemRefArg>& args);
void run_internal(const std::vector<MemRefArg>& args, bool firstIteration);
// Bind external tensors to MLIR module entry point
void bindArguments(const std::vector<MemRefArg>& args);
// Invokes an MLIR module entry point with bound arguments
void execute();
void execute(bool firstIteration);
// Cleans up allocated args
void cleanup();
......
......@@ -50,7 +50,7 @@ namespace ngraph
/// Overload with module op
void set_module(mlir::ModuleOp& module) { m_module = module; }
/// Executes a pre-compiled subgraph
virtual void run(const std::vector<MemRefArg>& args) = 0;
virtual void run(const std::vector<MemRefArg>& args, bool firstIteration) = 0;
/// Get the MLIR module that this runtime owns
mlir::OwningModuleRef& get_module() { return m_module; }
......
......@@ -136,13 +136,13 @@ namespace ngraph
mlir_backend.codegen();
// Store module into runtime, and invoke.
mlir_runtime.set_module(mlir_backend.get_module());
mlir_runtime.run(mem_ref_arg_vec);
mlir_runtime.run(mem_ref_arg_vec, true /*firstIteration*/);
}
else
{
// We have found a cached runtime, just invoke.
MLIRCPURuntime& mlir_runtime = it->second;
mlir_runtime.run(mem_ref_arg_vec);
mlir_runtime.run(mem_ref_arg_vec, false /*firstIteration*/);
}
};
......
......@@ -1191,13 +1191,15 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
if (getenv_bool("NGRAPH_MLIR") && getenv_bool("NGRAPH_MLIR_CALLBACK"))
{
if (typeid(ngraph::op::MatMul) == typeid(node) &&
node.get_input_element_type(0) == element::f32)
node.get_input_element_type(0) == element::f32 &&
node.get_input_shape(0).size() == 2 && node.get_input_shape(1).size() == 2)
{
return true;
}
if (typeid(ngraph::op::Gemm) == typeid(node) &&
node.get_input_element_type(0) == element::f32)
node.get_input_element_type(0) == element::f32 &&
node.get_input_shape(0).size() == 2 && node.get_input_shape(1).size() == 2)
{
return true;
}
......
......@@ -992,8 +992,12 @@ TEST(cpu_fusion, conv_horizontal_fusion)
auto cpu_results = execute(cpu_f, args, "CPU");
EXPECT_TRUE(test::all_close(cpu_results.at(0), int_results.at(0)));
size_t cpu_ck = count_ops_of_type<op::CompiledKernel>(cpu_f);
if (!cpu_ck)
{
size_t cpu_cb = count_ops_of_type<op::ConvolutionBias>(cpu_f);
ASSERT_EQ(cpu_cb, 1);
}
}
// ConvolutionBiasAdd relies on an in-place fused MKLDNN kernel.
......
......@@ -1052,7 +1052,12 @@ TEST(cpu_test, thread_safe_calls_convolution_2d_2items)
unset_environment("NGRAPH_CPU_CONCURRENCY");
}
TEST(cpu_test, constant_convertlayout)
// This test checks if a ConverLayout node is inserted before the ConvolutionBias node.
// Since MLIR supports ConvolutionBias through callback, the data layout conversion is done in
// callback.
// There is no ConvertLayout node when MLIR and MLIR CALLBACK are enabled.
// Thus this test is disabled with MLIR enabled.
TEST(cpu_test, MLIR_DISABLE_TEST(constant_convertlayout))
{
Shape data_shape{1, 64, 56, 56};
auto data = make_shared<op::Parameter>(element::f32, data_shape);
......
......@@ -4,13 +4,29 @@
// -----
// Convbias Op
// CHECK-LABEL: func @simple_convbias
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<1x1x3x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<1x1x3x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC2:.*]] = memref_cast %arg2 : memref<1xf32> to memref<*xf32>
// CHECK-DAG: %[[MC3:.*]] = memref_cast %arg3 : memref<1x1x1x1xf32> to memref<*xf32>
// CHECK: call @callback_3_inputs(%[[MC0]], %[[MC1]], %[[MC2]], %[[MC3]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_convbias(%arg0: !ng.tensor<1x1x3x3xf32>, %arg1: !ng.tensor<1x1x3x3xf32>, %arg2: !ng.tensor<1xf32>) -> !ng.tensor<1x1x1x1xf32> {
%0 = "ng.convBias"(%arg0, %arg1, %arg2) {dilation = [1, 1], padAbove = [0, 0], padBelow = [0, 0], strides = [1, 1], withRelu = false} : (!ng.tensor<1x1x3x3xf32>, !ng.tensor<1x1x3x3xf32>, !ng.tensor<1xf32>) -> !ng.tensor<1x1x1x1xf32>
"ng.return"(%0) : (!ng.tensor<1x1x1x1xf32>) -> ()
}
// -----
// Softmax Op
// CHECK-LABEL: func @simple_softmax
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @callback_1_input(%[[MC0]], %[[MC1]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
......@@ -20,13 +36,13 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// Gemm Op
// CHECK-LABEL: func @simple_gemm
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<3x6xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<3x6xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK-DAG: %[[MC2:.*]] = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK-DAG: %[[MC3:.*]] = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @callback_3_inputs(%[[MC0]], %[[MC1]], %[[MC2]], %[[MC3]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
......@@ -36,12 +52,12 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// MatMul Op
// CHECK-LABEL: func @simple_matmul
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC2:.*]] = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @callback_2_inputs(%[[MC0]], %[[MC1]], %[[MC2]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
......@@ -51,11 +67,11 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// AvePool Op
// CHECK-LABEL: func @simple_avgpool
// CHECK: %0 = memref_cast %arg0 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: call @callback_1_input(%[[MC0]], %[[MC1]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
......@@ -65,11 +81,11 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// AvgPoolBackprop Op
// CHECK-LABEL: func @simple_avgpoolbackprop
// CHECK: %0 = memref_cast %arg0 : memref<2x2x2x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<2x2x2x2xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: call @callback_1_input(%[[MC0]], %[[MC1]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
......@@ -79,11 +95,11 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// MaxPool Op
// CHECK-LABEL: func @simple_maxpool
// CHECK: %0 = memref_cast %arg0 : memref<64x3x7x8x10xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<64x3x7x8x10xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: call @callback_1_input(%[[MC0]], %[[MC1]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
......@@ -93,12 +109,12 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// MaxPoolBackprop Op
// CHECK-LABEL: func @simple_maxpoolbackprop
// CHECK: %0 = memref_cast %arg0 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x4x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK-DAG: %[[GA0:.*]] = llvm.mlir.addressof @{{[a-zA-Z_][a-zA-Z0-9_]*}} : !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">
// CHECK-DAG: %[[C0:.*]] = constant {{[0-9]+}} : i64
// CHECK-DAG: %[[MC0:.*]] = memref_cast %arg0 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK-DAG: %[[MC1:.*]] = memref_cast %arg1 : memref<2x2x4x3xf32> to memref<*xf32>
// CHECK-DAG: %[[MC2:.*]] = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: call @callback_2_inputs(%[[MC0]], %[[MC1]], %[[MC2]], %[[GA0]], %[[C0]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, !llvm<"{ i8, [3 x i64], [3 x i64], [3 x i64], [3 x i64] }*">, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
......
......@@ -250,8 +250,8 @@ func @depthToSpace(%arg0: !ng.tensor<1x8x2x2xf32>) -> !ng.tensor<1x2x4x4xf32>
//CHECK-LABEL: func @convBias
func @convBias(%arg0: !ng.tensor<1x3x2xf32>, %arg1: !ng.tensor<2x3x1xf32>, %arg2: !ng.tensor<2xf32>) -> (!ng.tensor<1x2x2xf32>)
{
//CHECK: %{{.*}} = "ng.convBias"(%{{.*}}, %{{.*}}, %{{.*}}) {padAbove = [0], padBelow = [0], strides = [1]} : (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32>
%0 = "ng.convBias"(%arg0, %arg1, %arg2) {padAbove=[0], padBelow=[0], strides=[1]}
//CHECK: %{{.*}} = "ng.convBias"(%{{.*}}, %{{.*}}, %{{.*}}) {dilation = [1], padAbove = [0], padBelow = [0], strides = [1]} : (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32>
%0 = "ng.convBias"(%arg0, %arg1, %arg2) {dilation=[1], padAbove=[0], padBelow=[0], strides=[1]}
: (!ng.tensor<1x3x2xf32>, !ng.tensor<2x3x1xf32>, !ng.tensor<2xf32>) -> !ng.tensor<1x2x2xf32>
"ng.return"(%0) : (!ng.tensor<1x2x2xf32>) -> ()
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment