Unverified commit e5258c0a, authored by Diego Caballero, committed by GitHub

[MLIR] Update MLIR repo (#4335)

* [MLIR] Update MLIR repo

* Nagy's fix

* Changes related to mlir-opt

* Update MLIR commit

* Update MLIR commit and callbacks.

* Disable 'noalias' attribute.

It will be re-introduced in a follow-up commit.

* Remove '__mlir' prefix in callback test

* Address feedback

* Fix EDSC includes

* Move MLIR repo forward

* Update type converter code

* Address feedback
Co-authored-by: Amy Zhuang <amyzhuang97@gmail.com>
parent fdd8db66
...@@ -19,7 +19,7 @@ include(ExternalProject) ...@@ -19,7 +19,7 @@ include(ExternalProject)
set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git) set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
# Change these commit IDs to move to latest stable versions # Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 96400ae) set(MLIR_LLVM_COMMIT_ID 376c6853)
# MLIR environment variables. Some of them are used by LIT tool. # MLIR environment variables. Some of them are used by LIT tool.
......
...@@ -26,7 +26,6 @@ ...@@ -26,7 +26,6 @@
#include <llvm/ADT/DenseSet.h> #include <llvm/ADT/DenseSet.h>
#include <map> #include <map>
#include <mlir/EDSC/Builders.h> #include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h> #include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h> #include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h> #include <mlir/IR/IntegerSet.h>
......
...@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect() ...@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
void MLIRCPUBackend::lowerStandardDialect() void MLIRCPUBackend::lowerStandardDialect()
{ {
mlir::PassManager pm(&m_context); mlir::PassManager pm(&m_context);
pm.addPass(mlir::createLowerToLLVMPass()); pm.addPass(mlir::createLowerToLLVMPass(
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));
// Apply any generic pass manager command line options. // Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm); mlir::applyPassManagerCLOptions(pm);
......
...@@ -28,9 +28,9 @@ ...@@ -28,9 +28,9 @@
#include <llvm/ADT/DenseSet.h> #include <llvm/ADT/DenseSet.h>
#include <llvm/Support/Debug.h> #include <llvm/Support/Debug.h>
#include <mlir/EDSC/Builders.h> #include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h> #include <mlir/Dialect/AffineOps/EDSC/Intrinsics.h>
#include <mlir/EDSC/Intrinsics.h> #include <mlir/Dialect/StandardOps/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h> #include <mlir/IR/AffineExpr.h>
#include <mlir/IR/Function.h> #include <mlir/IR/Function.h>
#include <mlir/IR/IntegerSet.h> #include <mlir/IR/IntegerSet.h>
...@@ -51,11 +51,10 @@ namespace ...@@ -51,11 +51,10 @@ namespace
{ {
using namespace mlir; using namespace mlir;
using namespace mlir::edsc; using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::edsc::op; using namespace mlir::edsc::op;
using namespace ngraph::runtime; using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
class DialectLoweringPass; class DialectLoweringPass;
...@@ -215,9 +214,37 @@ namespace ...@@ -215,9 +214,37 @@ namespace
NGraphTypeConverter() NGraphTypeConverter()
: TypeConverter() : TypeConverter()
{ {
} // TODO(dcaballe): split this into independent conversion patterns when there is a
// way to check if a type is valid in Std dialect.
addConversion([this](Type type) -> Type {
if (auto tensorType = type.dyn_cast<NGTensorType>())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std
// TensorType. This may change in the future.
return MemRefType::get(tensorType.getShape(),
convertType(tensorType.getElementType()),
{/* no map used */},
0);
}
if (auto floatType = type.dyn_cast<NGFloatType>())
{
// Float types are already std type.
return floatType;
}
if (auto intType = type.dyn_cast<NGIntegerType>())
{
return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
}
if (auto boolType = type.dyn_cast<NGBoolType>())
{
return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
Type convertType(Type t) override; // Do not assert/NGRAPH_CHECK here. Type conversion infra expects `convertType` to
// return the input type if the type is not supported.
return type;
});
}
}; };
/// Dialect Lowering Pass to affine ops /// Dialect Lowering Pass to affine ops
...@@ -317,7 +344,8 @@ namespace ...@@ -317,7 +344,8 @@ namespace
// TODO: Encode no alias attribute as part of the function signature conversion or as a // TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion. // separate rewrite pattern. Retrieve new function after signature conversion.
insertNoAliasArgAttrs(); // TODO: To be enabled in follow-up commit.
// insertNoAliasArgAttrs();
} }
opAttrsVec = m_attrsVec; opAttrsVec = m_attrsVec;
...@@ -492,22 +520,22 @@ namespace ...@@ -492,22 +520,22 @@ namespace
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe /// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics. /// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs() // void DialectLoweringPass::insertNoAliasArgAttrs()
{ //{
FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName); // FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found"); // NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
unsigned int argIdx = 0; // unsigned int argIdx = 0;
for (auto arg : func.getArguments()) // for (auto arg : func.getArguments())
{ // {
if (arg.getType().isa<MemRefType>()) // if (arg.getType().isa<MemRefType>())
{ // {
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext())); // func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
} // }
++argIdx; // ++argIdx;
} // }
} //}
void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter) void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
{ {
...@@ -543,40 +571,6 @@ namespace ...@@ -543,40 +571,6 @@ namespace
return m_attrsVec.size() - 1; return m_attrsVec.size() - 1;
} }
// NGDialect converters
Type NGraphTypeConverter::convertType(Type type)
{
// We may need to refactor this code to a external utility if type conversion is needed
// outside of the lowering context since NGraphTypeConverter is private.
if (auto tensorType = type.dyn_cast<NGTensorType>())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return MemRefType::get(tensorType.getShape(),
convertType(tensorType.getElementType()),
{/* no map used */},
0);
}
if (auto floatType = type.dyn_cast<NGFloatType>())
{
// Float types are already std type.
return floatType;
}
if (auto intType = type.dyn_cast<NGIntegerType>())
{
return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
}
if (auto boolType = type.dyn_cast<NGBoolType>())
{
return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
// Do not assert/NGRAPH_CHECK here. Type conversion infra expects `convertType` to return
// the input type if the type is not supported.
return type;
}
#define REWRITER(OP) \ #define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \ PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
...@@ -680,15 +674,15 @@ namespace ...@@ -680,15 +674,15 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
MemRefView vRes(result), vLHS(lhs); MemRefBoundsCapture vRes(result), vLHS(lhs);
// Index Values // Index Values
IndexedValue iRes(result), iLHS(lhs); AffineIndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles // Bounds Index Handles
auto lbs = vLHS.getLbs(); auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs(); auto ubs = vLHS.getUbs();
// Loop induction vars // Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank()); auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs)); auto pivs = makeHandlePointers(ivs);
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
...@@ -698,7 +692,7 @@ namespace ...@@ -698,7 +692,7 @@ namespace
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] { AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs); ValueHandle val = iLHS(ivs);
ValueHandle zero = createZeroConstant(elemTy); ValueHandle zero = createZeroConstant(elemTy);
iRes(ivs) = intrinsics::select(val > zero, val, zero); iRes(ivs) = std_select(val > zero, val, zero);
}); });
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
...@@ -742,36 +736,37 @@ namespace ...@@ -742,36 +736,37 @@ namespace
// res[n, k] += lhs[n, m] * rhs[m, k] // res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout. // TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
MemRefView vRes(result), vLhs(lhs), vRhs(rhs); MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs);
NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2, NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
"Dot operation is only supported for 2D tensors"); "Dot operation is only supported for 2D tensors");
// Create induction variables, lower bounds, upper bounds and steps of the loop nest. // Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// It's important to note that MemRefView provides lb/ub/step info in "reverse order", // It's important to note that MemRefBoundsCapture provides lb/ub/step info in "reverse
// i.e., fastest varying dimension is the last one, slowest varying dimension is the first // order", i.e., fastest varying dimension is the last one, slowest varying dimension is the
// one. // first one.
IndexHandle n, m, k; auto indexType = IndexType::get(rewriter.getContext());
ValueHandle n(indexType), m(indexType), k(indexType);
unsigned nDim = vLhs.fastestVarying() - 1; unsigned nDim = vLhs.fastestVarying() - 1;
unsigned mDim = vRhs.fastestVarying(); unsigned mDim = vRhs.fastestVarying();
unsigned kDim = vRhs.fastestVarying(); unsigned kDim = vRhs.fastestVarying();
IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim)); ValueHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim)); ValueHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim); int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);
// Constants and indexed values to be used inside the loop nest. // Constants and indexed values to be used inside the loop nest.
IndexedValue iRes(result), iLhs(lhs), iRhs(rhs); AffineIndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy))); ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));
{ {
IndexHandle n, k; ValueHandle n(indexType), k(indexType);
LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] { makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
LoopBuilder::makeAffine(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; }); makeAffineLoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
}); });
} }
LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] { makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
LoopBuilder::makeAffine(&m, mLb, mUb, mStep)([&] { makeAffineLoopBuilder(&m, mLb, mUb, mStep)([&] {
LoopBuilder::makeAffine(&k, kLb, kUb, kStep)( makeAffineLoopBuilder(&k, kLb, kUb, kStep)(
[&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); }); [&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
}); });
}); });
...@@ -792,13 +787,13 @@ namespace ...@@ -792,13 +787,13 @@ namespace
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp"); NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result. // Create view to write into result.
MemRefView vRes(result); MemRefBoundsCapture vRes(result);
auto rank = vRes.rank(); auto rank = vRes.rank();
// For each operand, generate a separate loop to copy into the target slice of "result". // For each operand, generate a separate loop to copy into the target slice of "result".
// We'll keep track of the slice offsets via concatenation_axis_pos. // We'll keep track of the slice offsets via concatenation_axis_pos.
auto concatenationAxis = concat.concatenation_axis().getSExtValue(); auto concatenationAxis = concat.concatenation_axis().getSExtValue();
IndexHandle concatenationAxisPos(index_type(0)); Value concatenationAxisPos(std_constant_index(0));
for (auto& operand : operands) for (auto& operand : operands)
{ {
...@@ -817,7 +812,7 @@ namespace ...@@ -817,7 +812,7 @@ namespace
// [i_(r-2)][i_(r-1)] // [i_(r-2)][i_(r-1)]
// := // :=
// operand[i_0][i_1]...[i_(r-2)][i_(r-1)] // operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
MemRefView vOperand(operand); MemRefBoundsCapture vOperand(operand);
NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch"); NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");
llvm::SmallVector<ValueHandle, 5> indexVars; llvm::SmallVector<ValueHandle, 5> indexVars;
...@@ -825,9 +820,10 @@ namespace ...@@ -825,9 +820,10 @@ namespace
llvm::SmallVector<ValueHandle, 5> indexVarLbs; llvm::SmallVector<ValueHandle, 5> indexVarLbs;
llvm::SmallVector<ValueHandle, 5> indexVarUbs; llvm::SmallVector<ValueHandle, 5> indexVarUbs;
llvm::SmallVector<int64_t, 5> indexVarSteps; llvm::SmallVector<int64_t, 5> indexVarSteps;
auto indexType = IndexType::get(rewriter.getContext());
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
{ {
indexVars.push_back(IndexHandle()); indexVars.push_back(ValueHandle(indexType));
indexVarPtrs.push_back(&(indexVars.back())); indexVarPtrs.push_back(&(indexVars.back()));
indexVarLbs.push_back(vOperand.lb(i)); indexVarLbs.push_back(vOperand.lb(i));
indexVarUbs.push_back(vOperand.ub(i)); indexVarUbs.push_back(vOperand.ub(i));
...@@ -835,15 +831,15 @@ namespace ...@@ -835,15 +831,15 @@ namespace
} }
AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] { AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
IndexedValue ivRes(result); AffineIndexedValue ivRes(result);
IndexedValue ivOperand(operand); AffineIndexedValue ivOperand(operand);
// On the LHS of the assignment, adjust the index for the concatenation axis. // On the LHS of the assignment, adjust the index for the concatenation axis.
llvm::SmallVector<ValueHandle, 5> resIndexHandles; llvm::SmallVector<ValueHandle, 5> resIndexHandles;
for (int i = 0; i < rank; i++) for (int i = 0; i < rank; i++)
{ {
resIndexHandles.push_back(i == concatenationAxis resIndexHandles.push_back(i == concatenationAxis
? indexVars[i] + concatenationAxisPos ? indexVars[i] + ValueHandle(concatenationAxisPos)
: indexVars[i]); : indexVars[i]);
} }
...@@ -851,11 +847,11 @@ namespace ...@@ -851,11 +847,11 @@ namespace
}); });
// Move up concatenation_axis_pos for the next operand. // Move up concatenation_axis_pos for the next operand.
concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis); concatenationAxisPos =
ValueHandle(concatenationAxisPos) + vOperand.ub(concatenationAxis);
} }
rewriter.replaceOp(op, {result}); rewriter.replaceOp(op, {result});
return matchSuccess(); return matchSuccess();
} }
...@@ -874,14 +870,13 @@ namespace ...@@ -874,14 +870,13 @@ namespace
auto axis = gatherOp.axis().getSExtValue(); auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result. // Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices); MemRefBoundsCapture vRes(result), vParams(params), vIndices(indices);
// Indexed Values // Indexed Values
IndexedValue iRes(result), iIndices(indices); AffineIndexedValue iRes(result), iIndices(indices);
StdIndexedValue iParams(params); StdIndexedValue iParams(params);
// Construct outer loop for params dims. Exclude the axis dim. // Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs; SmallVector<ValueHandle, 4> paramsLbs, paramsUbs, paramsIVs;
SmallVector<IndexHandle, 4> paramsIVs;
SmallVector<int64_t, 4> paramsSteps; SmallVector<int64_t, 4> paramsSteps;
SmallVector<ValueHandle*, 4> paramsIVPtrs; SmallVector<ValueHandle*, 4> paramsIVPtrs;
for (auto i = 0; i < vParams.rank(); i++) for (auto i = 0; i < vParams.rank(); i++)
...@@ -889,8 +884,8 @@ namespace ...@@ -889,8 +884,8 @@ namespace
// skip gather axis // skip gather axis
if (i == axis) if (i == axis)
continue; continue;
paramsLbs.push_back(IndexHandle(vParams.lb(i))); paramsLbs.push_back(vParams.lb(i));
paramsUbs.push_back(IndexHandle(vParams.ub(i))); paramsUbs.push_back(vParams.ub(i));
paramsSteps.push_back(vParams.step(i)); paramsSteps.push_back(vParams.step(i));
} }
NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 && NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
...@@ -898,17 +893,17 @@ namespace ...@@ -898,17 +893,17 @@ namespace
paramsSteps.size() == paramsLbs.size(), paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params"); "Incorrect loop nest bounds size for gather params");
paramsIVs = makeIndexHandles(vParams.rank() - 1); paramsIVs = ValueHandle::makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(paramsIVs)); paramsIVPtrs = makeHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs(); auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs(); auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps(); auto indicesSteps = vIndices.getSteps();
auto indicesIVs = makeIndexHandles(vIndices.rank()); auto indicesIVs = ValueHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(indicesIVs)); auto indicesIVPtrs = makeHandlePointers(indicesIVs);
SmallVector<IndexHandle, 8> paramsIndices, resIndices; SmallVector<ValueHandle, 8> paramsIndices, resIndices;
// Make sure we are going to create loops // Make sure we are going to create loops
NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps"); NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
...@@ -946,7 +941,7 @@ namespace ...@@ -946,7 +941,7 @@ namespace
{ {
if (i == axis) if (i == axis)
{ {
paramsIndices.push_back(IndexHandle(axisIdx)); paramsIndices.push_back(axisIdx);
} }
else else
{ {
...@@ -1022,10 +1017,10 @@ namespace ...@@ -1022,10 +1017,10 @@ namespace
NGRAPH_CHECK(groups > 0, "Invalid number of groups"); NGRAPH_CHECK(groups > 0, "Invalid number of groups");
// create outer group convolution loop // create outer group convolution loop
// for group = 0 to groups // for group = 0 to groups
IndexHandle iv; auto indexType = IndexType::get(rewriter.getContext());
ValueHandle iv(indexType);
ValueHandle lb = intrinsics::constant_index(0); ValueHandle lb = std_constant_index(0);
ValueHandle ub = intrinsics::constant_index(groups); ValueHandle ub = std_constant_index(groups);
auto imagesType = images.getType().cast<MemRefType>(); auto imagesType = images.getType().cast<MemRefType>();
auto filtersType = filters.getType().cast<MemRefType>(); auto filtersType = filters.getType().cast<MemRefType>();
...@@ -1043,13 +1038,13 @@ namespace ...@@ -1043,13 +1038,13 @@ namespace
NGRAPH_CHECK(groupsInFilters || filtersShape[0] % groups == 0, NGRAPH_CHECK(groupsInFilters || filtersShape[0] % groups == 0,
"Filters dim is not divisible by number of groups"); "Filters dim is not divisible by number of groups");
auto channelGroupSize = intrinsics::constant_index(imagesShape[1] / groups); auto channelGroupSize = std_constant_index(imagesShape[1] / groups);
auto filtersGroupSize = intrinsics::constant_index( auto filtersGroupSize =
groupsInFilters ? filtersShape[1] : filtersShape[0] / groups); std_constant_index(groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);
NGRAPH_CHECK(!groupsInFilters || groups == filtersShape[0]); NGRAPH_CHECK(!groupsInFilters || groups == filtersShape[0]);
LoopBuilder::makeAffine(&iv, lb, ub, 1)([&] { makeAffineLoopBuilder(&iv, lb, ub, 1)([&] {
// lower/upper bounds on image channel dim and kernels dim // lower/upper bounds on image channel dim and kernels dim
auto cLb = iv * channelGroupSize; auto cLb = iv * channelGroupSize;
auto cUb = cLb + channelGroupSize; auto cUb = cLb + channelGroupSize;
...@@ -1152,7 +1147,7 @@ namespace ...@@ -1152,7 +1147,7 @@ namespace
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl( FuncOp callBackFunc = pass.getCallDecl(
"__mlir_callback_2_inputs", "callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty}, {unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{}, {},
rewriter); rewriter);
...@@ -1245,7 +1240,7 @@ namespace ...@@ -1245,7 +1240,7 @@ namespace
auto int64Ty = rewriter.getIntegerType(64); auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0); auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl( auto callBackFunc = pass.getCallDecl(
"__mlir_callback_2_inputs", "callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty}, {unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{}, {},
rewriter); rewriter);
...@@ -1297,7 +1292,7 @@ namespace ...@@ -1297,7 +1292,7 @@ namespace
elemTy == biasTy.getElementType(), elemTy == biasTy.getElementType(),
"Types mismatch in GemmOp"); "Types mismatch in GemmOp");
MemRefView vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias); MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias);
NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2 && vBias.rank() <= 2, NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2 && vBias.rank() <= 2,
"Gemm operation is only supported for 2D tensors"); "Gemm operation is only supported for 2D tensors");
...@@ -1361,7 +1356,7 @@ namespace ...@@ -1361,7 +1356,7 @@ namespace
auto int64Ty = rewriter.getIntegerType(64); auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0); auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl("__mlir_callback_3_inputs", auto callBackFunc = pass.getCallDecl("callback_3_inputs",
{unrankedMemrefTy, {unrankedMemrefTy,
unrankedMemrefTy, unrankedMemrefTy,
unrankedMemrefTy, unrankedMemrefTy,
...@@ -1425,7 +1420,7 @@ namespace ...@@ -1425,7 +1420,7 @@ namespace
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64); rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64);
FuncOp callBackFunc = FuncOp callBackFunc =
pass.getCallDecl("__mlir_callback_1_input", pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty}, {unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{}, {},
rewriter); rewriter);
...@@ -1511,11 +1506,12 @@ namespace ...@@ -1511,11 +1506,12 @@ namespace
auto padBelow = padBelowAttr.getValue(); auto padBelow = padBelowAttr.getValue();
auto padAbove = padBelowAttr.getValue(); auto padAbove = padBelowAttr.getValue();
Type elemTy = images.getType().cast<MemRefType>().getElementType(); Type elemTy = images.getType().cast<MemRefType>().getElementType();
auto indexType = IndexType::get(rewriter.getContext());
// Create views // Create views
MemRefView vRes(result), vImages(images), vFilters(filters); MemRefBoundsCapture vRes(result), vImages(images), vFilters(filters);
// Create indexed Values // Create indexed Values
IndexedValue iRes(result), iImages(images), iFilters(filters); AffineIndexedValue iRes(result), iImages(images), iFilters(filters);
// Bounds on batch size N // Bounds on batch size N
ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0); ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
// Bounds on spatial dimensions // Bounds on spatial dimensions
...@@ -1526,9 +1522,8 @@ namespace ...@@ -1526,9 +1522,8 @@ namespace
unsigned spatialRank = vImages.rank() - 2; unsigned spatialRank = vImages.rank() - 2;
// Result spatial indices and bounds // Result spatial indices and bounds
auto resSpatialIndices = makeIndexHandles(spatialRank); auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
SmallVector<int64_t, 4> resSteps, filtersSteps; SmallVector<int64_t, 4> resSteps, filtersSteps;
SmallVector<int, 4> padBelowIntValues; SmallVector<int, 4> padBelowIntValues;
bool withPadding = false; bool withPadding = false;
...@@ -1610,9 +1605,8 @@ namespace ...@@ -1610,9 +1605,8 @@ namespace
"Results spatial dims mismatches input"); "Results spatial dims mismatches input");
// Filters spatial indices and bounds // Filters spatial indices and bounds
auto filtersSpatialIndices = makeIndexHandles(spatialRank); auto filtersSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto filtersSpatialIndicesPtrs = auto filtersSpatialIndicesPtrs = makeHandlePointers(filtersSpatialIndices);
makeHandlePointers(MutableArrayRef<IndexHandle>(filtersSpatialIndices));
for (auto i = 0; i < spatialRank; i++) for (auto i = 0; i < spatialRank; i++)
{ {
...@@ -1658,23 +1652,22 @@ namespace ...@@ -1658,23 +1652,22 @@ namespace
// Initialize output to zero // Initialize output to zero
{ {
IndexHandle n, k, c; ValueHandle n(indexType), k(indexType), c(indexType);
auto resSpatialIndices = makeIndexHandles(spatialRank); auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] { makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] { makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
AffineLoopNestBuilder( AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] { resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
SmallVector<IndexHandle, 4> resIndices; SmallVector<ValueHandle, 4> resIndices;
// Result indices // Result indices
resIndices.push_back(n); resIndices.push_back(n);
if (groupConvolution && groupsInFilters) if (groupConvolution && groupsInFilters)
{ {
// compute global C_OUT from gID and k // compute global C_OUT from gID and k
// gId * C_OUT (num of filters) + k // gId * C_OUT (num of filters) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k)); resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
} }
else else
{ {
...@@ -1689,31 +1682,31 @@ namespace ...@@ -1689,31 +1682,31 @@ namespace
}); });
} }
IndexHandle n, k, c; ValueHandle n(indexType), k(indexType), c(indexType);
// Convolution loop // Convolution loop
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] { makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
// Number of filters loop // Number of filters loop
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] { makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
// Channels loop // Channels loop
LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] { makeAffineLoopBuilder(&c, numChannelsLb, numChannelsUb, 1)([&] {
// Results loop // Results loop
AffineLoopNestBuilder( AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] { resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
// Compute image start indices // Compute image start indices
SmallVector<IndexHandle, 4> imgStartIndices; SmallVector<ValueHandle, 4> imgStartIndices;
for (auto i = 0; i < spatialRank; i++) for (auto i = 0; i < spatialRank; i++)
{ {
IntegerAttr iAttr = strides[i].cast<IntegerAttr>(); IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
auto stride = intrinsics::constant_index(iAttr.getInt()); auto stride = std_constant_index(iAttr.getInt());
imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride)); imgStartIndices.push_back(resSpatialIndices[i] * stride);
} }
SmallVector<IndexHandle, 4> resIndices; SmallVector<ValueHandle, 4> resIndices;
// Result indices // Result indices
resIndices.push_back(n); resIndices.push_back(n);
if (groupConvolution && groupsInFilters) if (groupConvolution && groupsInFilters)
{ {
// gId * C_OUT (num of filters) + k // gId * C_OUT (num of filters) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k)); resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
} }
else else
{ {
...@@ -1727,15 +1720,14 @@ namespace ...@@ -1727,15 +1720,14 @@ namespace
filtersSpatialLbs, filtersSpatialLbs,
filtersSpatialUbs, filtersSpatialUbs,
filtersSteps)([&] { filtersSteps)([&] {
SmallVector<IndexHandle, 4> imgIndices, filtersIndices; SmallVector<ValueHandle, 4> imgIndices, filtersIndices;
// Image indices // Image indices
// Here we compute the virtual start index into the padded image. // Here we compute the virtual start index into the padded image.
imgIndices.push_back(n); imgIndices.push_back(n);
imgIndices.push_back(c); imgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++) for (auto i = 0; i < spatialRank; i++)
{ {
imgIndices.push_back( imgIndices.push_back(imgStartIndices[i] + filtersSpatialIndices[i]);
IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
} }
// Filter indices // Filter indices
...@@ -1744,14 +1736,14 @@ namespace ...@@ -1744,14 +1736,14 @@ namespace
// index // index
if (groupConvolution && groupsInFilters) if (groupConvolution && groupsInFilters)
{ {
filtersIndices.push_back(IndexHandle(gId)); filtersIndices.push_back(ValueHandle(gId));
} }
filtersIndices.push_back(k); filtersIndices.push_back(k);
// subtract lower bound of channel // subtract lower bound of channel
// if we are doing group convolution this bound will advance based // if we are doing group convolution this bound will advance based
// on the group id. For the filters, it should always start from 0 // on the group id. For the filters, it should always start from 0
filtersIndices.push_back(IndexHandle(c - numChannelsLb)); filtersIndices.push_back(c - numChannelsLb);
filtersIndices.insert(filtersIndices.end(), filtersIndices.insert(filtersIndices.end(),
filtersSpatialIndices.begin(), filtersSpatialIndices.begin(),
filtersSpatialIndices.end()); filtersSpatialIndices.end());
...@@ -1759,7 +1751,7 @@ namespace ...@@ -1759,7 +1751,7 @@ namespace
if (withPadding) if (withPadding)
{ {
// if args : img dims, img lbs, img ubs // if args : img dims, img lbs, img ubs
SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin(); SmallVector<ValueHandle, 4>::iterator it = imgIndices.begin();
std::advance(it, 2); std::advance(it, 2);
SmallVector<Value, 4> affineIfArgs(it, imgIndices.end()); SmallVector<Value, 4> affineIfArgs(it, imgIndices.end());
affineIfArgs.insert( affineIfArgs.insert(
...@@ -1777,14 +1769,14 @@ namespace ...@@ -1777,14 +1769,14 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// We must subtract pad below before img load, since the // We must subtract pad below before img load, since the
// physical image is not padded // physical image is not padded
SmallVector<IndexHandle, 4> adjustedImgIndices; SmallVector<ValueHandle, 4> adjustedImgIndices;
adjustedImgIndices.push_back(n); adjustedImgIndices.push_back(n);
adjustedImgIndices.push_back(c); adjustedImgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++) for (auto i = 0; i < spatialRank; i++)
{ {
adjustedImgIndices.push_back(IndexHandle( adjustedImgIndices.push_back(
imgIndices[2 + i] - imgIndices[2 + i] -
intrinsics::constant_index(padBelowIntValues[i]))); std_constant_index(padBelowIntValues[i]));
} }
iRes(resIndices) = iRes(resIndices) =
iRes(resIndices) + iRes(resIndices) +
...@@ -1821,15 +1813,15 @@ namespace ...@@ -1821,15 +1813,15 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
MemRefView vRes(result), vLHS(lhs); MemRefBoundsCapture vRes(result), vLHS(lhs);
// Index Values // Index Values
IndexedValue iRes(result), iLHS(lhs); AffineIndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles // Bounds Index Handles
auto lbs = vLHS.getLbs(); auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs(); auto ubs = vLHS.getUbs();
// Loop induction vars // Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank()); auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs)); auto pivs = makeHandlePointers(ivs);
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
...@@ -1867,15 +1859,15 @@ namespace ...@@ -1867,15 +1859,15 @@ namespace
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
// Views // Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs); MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values // Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs); AffineIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles // Bounds Index Handles
auto lbs = vLHS.getLbs(); auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs(); auto ubs = vLHS.getUbs();
// Loop induction vars // Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank()); auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs)); auto pivs = makeHandlePointers(ivs);
// Steps // Steps
auto steps = vLHS.getSteps(); auto steps = vLHS.getSteps();
// element type of the operand // element type of the operand
...@@ -1900,65 +1892,57 @@ namespace ...@@ -1900,65 +1892,57 @@ namespace
iRes(ivs) = iLHS(ivs) / iRHS(ivs); iRes(ivs) = iLHS(ivs) / iRHS(ivs);
} }
// TODO(pthoreho) For all comparison operators, use // TODO(pthoreho) For all comparison operators, use
// edsc::intrinsics::zero_extendi(ValueHandle(iLHS(ivs)) != // zero_extendi(ValueHandle(iLHS(ivs)) !=
// ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext())); // ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
// instead of edsc::intrinsics::select once `zero_extendi` is // instead of std_select once `zero_extendi` is
// made available in the edsc::intrinsics namespace in MLIR repo. // made available in the edsc::intrinsics namespace in MLIR repo.
else if (isa<NGGreaterOp>(op)) else if (isa<NGGreaterOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGLessOp>(op)) else if (isa<NGLessOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGGreaterEqOp>(op)) else if (isa<NGGreaterEqOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGLessEqOp>(op)) else if (isa<NGLessEqOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGEqOp>(op)) else if (isa<NGEqOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGNotEqOp>(op)) else if (isa<NGNotEqOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)), createOneConstant(elemTy),
createOneConstant(elemTy), createZeroConstant(elemTy));
createZeroConstant(elemTy));
} }
else if (isa<NGMaxOp>(op)) else if (isa<NGMaxOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)), ValueHandle(iLHS(ivs)),
ValueHandle(iLHS(ivs)), ValueHandle(iRHS(ivs)));
ValueHandle(iRHS(ivs)));
} }
else if (isa<NGMinOp>(op)) else if (isa<NGMinOp>(op))
{ {
iRes(ivs) = iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)), ValueHandle(iLHS(ivs)),
ValueHandle(iLHS(ivs)), ValueHandle(iRHS(ivs)));
ValueHandle(iRHS(ivs)));
} }
else else
{ {
...@@ -1995,10 +1979,10 @@ namespace ...@@ -1995,10 +1979,10 @@ namespace
Value result = pass.buildOutputDefs(op, rewriter)[0]; Value result = pass.buildOutputDefs(op, rewriter)[0];
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefBoundsCapture vRes(result), vArg(arg);
// Index Values // Index Values
StdIndexedValue iRes(result), stdArg(arg); StdIndexedValue iRes(result), stdArg(arg);
IndexedValue affineArg(arg); AffineIndexedValue affineArg(arg);
// Bounds Index Handles // Bounds Index Handles
auto resLbs = vRes.getLbs(); auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs(); auto resUbs = vRes.getUbs();
...@@ -2008,8 +1992,8 @@ namespace ...@@ -2008,8 +1992,8 @@ namespace
Type resTy = result.getType().cast<MemRefType>().getElementType(); Type resTy = result.getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced. // Generate loop nest that initializes result to lower bound of the axis to be reduced.
{ {
auto ivs = makeIndexHandles(vRes.rank()); auto ivs = ValueHandle::makeIndexHandles(vRes.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs)); auto pivs = makeHandlePointers(ivs);
auto steps = vRes.getSteps(); auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis); auto initVal = vArg.lb(axis);
AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)( AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)(
...@@ -2018,10 +2002,10 @@ namespace ...@@ -2018,10 +2002,10 @@ namespace
// Generate loop nest that computes the actual index reduction. // Generate loop nest that computes the actual index reduction.
{ {
auto allIVs = makeIndexHandles(vArg.rank()); auto allIVs = ValueHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = makeHandlePointers(MutableArrayRef<IndexHandle>(allIVs)); auto pAllIVs = makeHandlePointers(allIVs);
auto steps = vArg.getSteps(); auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs; SmallVector<ValueHandle, 8> nonRedIVs;
Type resTy = result.getType().cast<MemRefType>().getElementType(); Type resTy = result.getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(), NGRAPH_CHECK(resTy.isa<IntegerType>(),
...@@ -2049,10 +2033,8 @@ namespace ...@@ -2049,10 +2033,8 @@ namespace
// Select the min/max value and cast it back to integer type before storing it. // Select the min/max value and cast it back to integer type before storing it.
ValueHandle newRedIdx = ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>() std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select( ? std_select(affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx) : std_select(stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
: edsc::intrinsics::select(
stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy); iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
}); });
...@@ -2123,7 +2105,7 @@ namespace ...@@ -2123,7 +2105,7 @@ namespace
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy); castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = FuncOp callBackFunc =
pass.getCallDecl("__mlir_callback_1_input", pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty}, {unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{}, {},
rewriter); rewriter);
...@@ -2168,11 +2150,11 @@ namespace ...@@ -2168,11 +2150,11 @@ namespace
{ {
if (floatTy.isF32()) if (floatTy.isF32())
{ {
return intrinsics::constant_float(llvm::APFloat(0.0f), floatTy); return std_constant_float(llvm::APFloat(0.0f), floatTy);
} }
else if (floatTy.isF64()) else if (floatTy.isF64())
{ {
return intrinsics::constant_float(llvm::APFloat(0.0), floatTy); return std_constant_float(llvm::APFloat(0.0), floatTy);
} }
else else
{ {
...@@ -2181,7 +2163,7 @@ namespace ...@@ -2181,7 +2163,7 @@ namespace
} }
else if (auto intTy = type.dyn_cast<IntegerType>()) else if (auto intTy = type.dyn_cast<IntegerType>())
{ {
return intrinsics::constant_int(0, intTy.getWidth()); return std_constant_int(0, intTy.getWidth());
} }
NGRAPH_UNREACHABLE("Unsupported type"); NGRAPH_UNREACHABLE("Unsupported type");
} }
...@@ -2192,11 +2174,11 @@ namespace ...@@ -2192,11 +2174,11 @@ namespace
{ {
if (floatTy.isF32()) if (floatTy.isF32())
{ {
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy); return std_constant_float(llvm::APFloat(1.0f), floatTy);
} }
else if (floatTy.isF64()) else if (floatTy.isF64())
{ {
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy); return std_constant_float(llvm::APFloat(1.0f), floatTy);
} }
else else
{ {
...@@ -2205,7 +2187,7 @@ namespace ...@@ -2205,7 +2187,7 @@ namespace
} }
else if (auto intTy = type.dyn_cast<IntegerType>()) else if (auto intTy = type.dyn_cast<IntegerType>())
{ {
return intrinsics::constant_int(1, intTy.getWidth()); return std_constant_int(1, intTy.getWidth());
} }
NGRAPH_UNREACHABLE("Unsupported type"); NGRAPH_UNREACHABLE("Unsupported type");
} }
......
...@@ -23,9 +23,7 @@ ...@@ -23,9 +23,7 @@
#include "contrib/mlir/core/ngraph_dialect/type.hpp" #include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include <llvm/IR/Module.h> #include <llvm/IR/Module.h>
#include <mlir/EDSC/Builders.h> #include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/IntegerSet.h> #include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h> #include <mlir/IR/StandardTypes.h>
......
...@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA, ...@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
} }
} }
extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type) extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input); auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input);
auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output); auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
...@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, ...@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
} }
} }
extern "C" void extern "C" void _mlir_ciface_callback_2_inputs(
__mlir_callback_2_inputs(void* input0, void* input1, void* output, size_t index, OpType type) void* input0, void* input1, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0); auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1); auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
...@@ -780,7 +780,7 @@ extern "C" void ...@@ -780,7 +780,7 @@ extern "C" void
} }
} }
extern "C" void __mlir_callback_3_inputs( extern "C" void _mlir_ciface_callback_3_inputs(
void* input0, void* input1, void* input2, void* output, size_t index, OpType type) void* input0, void* input1, void* input2, void* output, size_t index, OpType type)
{ {
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0); auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
......
...@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args) ...@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{ {
NGRAPH_CHECK(m_module, "MLIR module is not ready."); NGRAPH_CHECK(m_module, "MLIR module is not ready.");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("main"); auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found"); NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments // Set external arguments
...@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute() ...@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
// uniformity reasons, it takes a list of type-erased pointers to arguments. // uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please, note that 'invoke' method is overloaded with a parameter pack version. // Please, note that 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked. // Make sure the MutableArrayRef version is invoked.
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs)); auto invocationResult =
m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile) if (clDumpObjectFile)
{ {
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o" m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue()); : clObjectFilename.getValue());
} }
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n"); NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
} }
void MLIRCPURuntime::cleanup() void MLIRCPURuntime::cleanup()
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
set(LIBS set(LIBS
mlir_backend mlir_backend
MLIROptMain MLIROptLib
MLIRPass MLIRPass
MLIRParser MLIRParser
LLVMSupport LLVMSupport
......
...@@ -21,10 +21,21 @@ ...@@ -21,10 +21,21 @@
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp" #include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include <llvm/Support/CommandLine.h> #include <mlir/Dialect/AffineOps/AffineOps.h>
#include <llvm/Support/Debug.h> #include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/LoopOps/LoopOps.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include <mlir/Dialect/VectorOps/VectorOps.h>
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/LocationSnapshot.h>
#include <mlir/Transforms/Passes.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
using namespace mlir;
static llvm::cl::opt<bool> clPrintIRAfterAll( static llvm::cl::opt<bool> clPrintIRAfterAll(
"ngraph-print-ir-after-all", "ngraph-print-ir-after-all",
...@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll( ...@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
void ngraph::runtime::ngmlir::initializeNGraphMLIR() void ngraph::runtime::ngmlir::initializeNGraphMLIR()
{ {
// Initialize a dialect only once. // Initialize MLIR dialects and passes only once.
// We currently have no way to query if a dialect is previously static bool init_once = []() {
// registered. So using a global flag instead. // In-tree Dialects.
static bool init = false; registerDialect<AffineOpsDialect>();
if (!init) registerDialect<LLVM::LLVMDialect>();
{ registerDialect<loop::LoopOpsDialect>();
mlir::registerDialect<mlir::NGraphOpsDialect>(); registerDialect<StandardOpsDialect>();
init = true; registerDialect<vector::VectorOpsDialect>();
}
// nGraph dialects.
registerDialect<mlir::NGraphOpsDialect>();
// In-tree passes.
// No-op to avoid DCE on the following pass initializations.
if (std::getenv("bar") != (char*)-1)
return false;
createCanonicalizerPass();
createCSEPass();
createVectorizePass({});
createLoopUnrollPass();
createLoopUnrollAndJamPass();
createSimplifyAffineStructuresPass();
createLoopFusionPass();
createLoopInvariantCodeMotionPass();
createAffineLoopInvariantCodeMotionPass();
createPipelineDataTransferPass();
createLowerAffinePass();
createLoopTilingPass(0);
createLoopCoalescingPass();
createAffineDataCopyGenerationPass(0, 0);
createMemRefDataFlowOptPass();
createStripDebugInfoPass();
createPrintOpStatsPass();
createInlinerPass();
createSymbolDCEPass();
createLocationSnapshotPass({});
return true;
}();
(void)init_once;
} }
void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module) void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module)
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32> // CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> { func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32> %0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
...@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) - ...@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32> // CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> { func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> %0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> () "ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
...@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: ...@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32> // CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> { func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> %0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
...@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> ! ...@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> { func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> %0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
...@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> ...@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> { func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> %0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
...@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3 ...@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32> // CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> { func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> %0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> () "ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
...@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x ...@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32> // CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64 // CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64 // CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> () // CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> { func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> %0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> () "ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment