Unverified commit e5258c0a, authored by Diego Caballero and committed by GitHub

[MLIR] Update MLIR repo (#4335)

* [MLIR] Update MLIR repo

* Nagy's fix

* Changes related to mlir-opt

* Update MLIR commit

* Update MLIR commit and callbacks.

* Disable 'noalias' attribute.

It will be re-introduced in a follow-up commit.

* Remove '__mlir' prefix in callback test

* Address feedback

* Fix EDSC includes

* Move MLIR repo forward

* Update type converter code

* Address feedback
Co-authored-by: Amy Zhuang <amyzhuang97@gmail.com>
parent fdd8db66
@@ -19,7 +19,7 @@ include(ExternalProject)
set(MLIR_LLVM_REPO_URL https://github.com/llvm/llvm-project.git)
# Change these commit IDs to move to latest stable versions
set(MLIR_LLVM_COMMIT_ID 96400ae)
set(MLIR_LLVM_COMMIT_ID 376c6853)
# MLIR environment variables. Some of them are used by the LIT tool.
...
@@ -26,7 +26,6 @@
#include <llvm/ADT/DenseSet.h>
#include <map>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h>
...
@@ -194,7 +194,8 @@ void MLIRCPUBackend::lowerNgDialect()
void MLIRCPUBackend::lowerStandardDialect()
{
mlir::PassManager pm(&m_context);
pm.addPass(mlir::createLowerToLLVMPass());
pm.addPass(mlir::createLowerToLLVMPass(
/*useAlloca=*/false, /*useBarePtrCallConv=*/false, /*emitCWrappers=*/true));
// Apply any generic pass manager command line options.
mlir::applyPassManagerCLOptions(pm);
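Context: emitCWrappers=true makes the LLVM lowering emit, for each public or
external function `foo`, a C-interface wrapper `_mlir_ciface_foo` that passes
memref arguments as pointers to their descriptors. A minimal sketch of the
C-side view of such a wrapper, assuming the StaticMemRef descriptor used by
the runtime callbacks further down (the function name `foo` is hypothetical):

    // Generated wrapper symbol for an MLIR function foo(memref, memref):
    // each memref argument arrives as a pointer to its descriptor.
    extern "C" void _mlir_ciface_foo(StaticMemRef* arg, StaticMemRef* result);

This is why the runtime hunks below switch their lookups from "main" to
"_mlir_ciface_main".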
...
@@ -28,9 +28,9 @@
#include <llvm/ADT/DenseSet.h>
#include <llvm/Support/Debug.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/Dialect/AffineOps/EDSC/Intrinsics.h>
#include <mlir/Dialect/StandardOps/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/IntegerSet.h>
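The EDSC headers were reorganized in this MLIR revision; roughly, the
monolithic helpers moved into per-dialect headers:

    // mlir/EDSC/Helpers.h, mlir/EDSC/Intrinsics.h   (removed)
    //   -> mlir/Dialect/AffineOps/EDSC/Builders.h
    //   -> mlir/Dialect/AffineOps/EDSC/Intrinsics.h
    //   -> mlir/Dialect/StandardOps/EDSC/Intrinsics.h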
@@ -51,11 +51,10 @@ namespace
{
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::edsc::op;
using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
class DialectLoweringPass;
@@ -215,9 +214,37 @@ namespace
NGraphTypeConverter()
: TypeConverter()
{
}
// TODO(dcaballe): split this into independent conversion patterns when there is a
// way to check if a type is valid in Std dialect.
addConversion([this](Type type) -> Type {
if (auto tensorType = type.dyn_cast<NGTensorType>())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std
// TensorType. This may change in the future.
return MemRefType::get(tensorType.getShape(),
convertType(tensorType.getElementType()),
{/* no map used */},
0);
}
if (auto floatType = type.dyn_cast<NGFloatType>())
{
// Float types are already std types.
return floatType;
}
if (auto intType = type.dyn_cast<NGIntegerType>())
{
return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
}
if (auto boolType = type.dyn_cast<NGBoolType>())
{
return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
Type convertType(Type t) override;
// Do not assert/NGRAPH_CHECK here. Type conversion infra expects `convertType` to
// return the input type if the type is not supported.
return type;
});
}
};
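This follows the updated TypeConverter API: conversions are registered as
callbacks from the constructor instead of overriding a virtual convertType
(the old override is deleted further down). A minimal sketch of the pattern
with a pass-through fallback, independent of the nGraph types (the class
name is hypothetical):

    #include <mlir/Transforms/DialectConversion.h>

    class MyTypeConverter : public mlir::TypeConverter
    {
    public:
        MyTypeConverter()
        {
            addConversion([](mlir::Type t) -> mlir::Type {
                // At this MLIR revision the conversion infra expects the
                // input type back, unchanged, when no conversion applies.
                return t;
            });
        }
    };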
/// Dialect Lowering Pass to affine ops
@@ -317,7 +344,8 @@ namespace
// TODO: Encode no alias attribute as part of the function signature conversion or as a
// separate rewrite pattern. Retrieve new function after signature conversion.
insertNoAliasArgAttrs();
// TODO: To be enabled in follow-up commit.
// insertNoAliasArgAttrs();
}
opAttrsVec = m_attrsVec;
@@ -492,22 +520,22 @@ namespace
/// Add llvm.noalias attribute to all the memref function arguments. We know that this is safe
/// by nGraph op semantics.
void DialectLoweringPass::insertNoAliasArgAttrs()
{
FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
unsigned int argIdx = 0;
for (auto arg : func.getArguments())
{
if (arg.getType().isa<MemRefType>())
{
func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
}
++argIdx;
}
}
// void DialectLoweringPass::insertNoAliasArgAttrs()
//{
// FuncOp func = getModule().lookupSymbol<mlir::FuncOp>(funcName);
// NGRAPH_CHECK(func, "FuncOp '" + funcName.str() + "' not found");
// unsigned int argIdx = 0;
// for (auto arg : func.getArguments())
// {
// if (arg.getType().isa<MemRefType>())
// {
// func.setArgAttr(argIdx, "llvm.noalias", BoolAttr::get(true, &getContext()));
// }
// ++argIdx;
// }
//}
void DialectLoweringPass::insertDeallocs(PatternRewriter& rewriter)
{
@@ -543,40 +571,6 @@ namespace
return m_attrsVec.size() - 1;
}
// NGDialect converters
Type NGraphTypeConverter::convertType(Type type)
{
// We may need to refactor this code into an external utility if type conversion is needed
// outside of the lowering context since NGraphTypeConverter is private.
if (auto tensorType = type.dyn_cast<NGTensorType>())
{
// Convert NGTensorType to Std MemRefType directly instead of going to Std TensorType.
// This may change in the future.
return MemRefType::get(tensorType.getShape(),
convertType(tensorType.getElementType()),
{/* no map used */},
0);
}
if (auto floatType = type.dyn_cast<NGFloatType>())
{
// Float types are already std types.
return floatType;
}
if (auto intType = type.dyn_cast<NGIntegerType>())
{
return mlir::IntegerType::get(intType.getWidth(), intType.getContext());
}
if (auto boolType = type.dyn_cast<NGBoolType>())
{
return mlir::IntegerType::get(1 /* width */, boolType.getContext());
}
// Do not assert/NGRAPH_CHECK here. Type conversion infra expects `convertType` to return
// the input type if the type is not supported.
return type;
}
#define REWRITER(OP) \
PatternMatchResult OP##Conversion::matchAndRewrite( \
Operation* op, ArrayRef<Value> operands, ConversionPatternRewriter& rewriter) const
@@ -680,15 +674,15 @@ namespace
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs);
MemRefBoundsCapture vRes(result), vLHS(lhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs);
AffineIndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
@@ -698,7 +692,7 @@ namespace
AffineLoopNestBuilder(pivs, lbs, ubs, steps)([&] {
ValueHandle val = iLHS(ivs);
ValueHandle zero = createZeroConstant(elemTy);
iRes(ivs) = intrinsics::select(val > zero, val, zero);
iRes(ivs) = std_select(val > zero, val, zero);
});
rewriter.replaceOp(op, {result});
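For reference, the EDSC renames applied mechanically throughout this file by
the MLIR bump (all grounded in the hunks around this one):

    // MemRefView                       -> MemRefBoundsCapture
    // IndexedValue                     -> AffineIndexedValue
    // IndexHandle iv;                  -> ValueHandle iv(indexType);
    // makeIndexHandles(rank)           -> ValueHandle::makeIndexHandles(rank)
    // makeHandlePointers(MutableArrayRef<IndexHandle>(ivs))
    //                                  -> makeHandlePointers(ivs)
    // LoopBuilder::makeAffine(...)     -> makeAffineLoopBuilder(...)
    // intrinsics::select               -> std_select
    // intrinsics::constant_index       -> std_constant_index
    // intrinsics::constant_float/int   -> std_constant_float/std_constant_int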
@@ -742,36 +736,37 @@ namespace
// res[n, k] += lhs[n, m] * rhs[m, k]
// TODO (dcab): We currently generate a super naive loop nest. Improve loop nest layout.
MemRefView vRes(result), vLhs(lhs), vRhs(rhs);
MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs);
NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2,
"Dot operation is only supported for 2D tensors");
// Create induction variables, lower bounds, upper bounds and steps of the loop nest.
// It's important to note that MemRefView provides lb/ub/step info in "reverse order",
// i.e., the fastest varying dimension is the last one and the slowest varying dimension is the first
// one.
IndexHandle n, m, k;
// It's important to note that MemRefBoundsCapture provides lb/ub/step info in "reverse
// order", i.e., the fastest varying dimension is the last one and the slowest varying
// dimension is the first one.
auto indexType = IndexType::get(rewriter.getContext());
ValueHandle n(indexType), m(indexType), k(indexType);
unsigned nDim = vLhs.fastestVarying() - 1;
unsigned mDim = vRhs.fastestVarying();
unsigned kDim = vRhs.fastestVarying();
IndexHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
IndexHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
ValueHandle nLb(vLhs.lb(nDim)), mLb(vLhs.lb(mDim)), kLb(vRhs.lb(kDim));
ValueHandle nUb(vLhs.ub(nDim)), mUb(vLhs.ub(mDim)), kUb(vRhs.ub(kDim));
int64_t nStep = vLhs.step(nDim), mStep = vLhs.step(mDim), kStep = vRhs.step(kDim);
// Constants and indexed values to be used inside the loop nest.
IndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
AffineIndexedValue iRes(result), iLhs(lhs), iRhs(rhs);
ValueHandle zeroInit(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elemTy)));
{
IndexHandle n, k;
LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
LoopBuilder::makeAffine(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
ValueHandle n(indexType), k(indexType);
makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
makeAffineLoopBuilder(&k, kLb, kUb, kStep)([&] { iRes(n, k) = zeroInit; });
});
}
LoopBuilder::makeAffine(&n, nLb, nUb, nStep)([&] {
LoopBuilder::makeAffine(&m, mLb, mUb, mStep)([&] {
LoopBuilder::makeAffine(&k, kLb, kUb, kStep)(
makeAffineLoopBuilder(&n, nLb, nUb, nStep)([&] {
makeAffineLoopBuilder(&m, mLb, mUb, mStep)([&] {
makeAffineLoopBuilder(&k, kLb, kUb, kStep)(
[&] { iRes(n, k) += iLhs(n, m) * iRhs(m, k); });
});
});
@@ -792,13 +787,13 @@ namespace
NGRAPH_CHECK(result, "Unexpected null result in ConcatOp");
// Create view to write into result.
MemRefView vRes(result);
MemRefBoundsCapture vRes(result);
auto rank = vRes.rank();
// For each operand, generate a separate loop to copy into the target slice of "result".
// We'll keep track of the slice offsets via concatenation_axis_pos.
auto concatenationAxis = concat.concatenation_axis().getSExtValue();
IndexHandle concatenationAxisPos(index_type(0));
Value concatenationAxisPos(std_constant_index(0));
for (auto& operand : operands)
{
@@ -817,7 +812,7 @@ namespace
// [i_(r-2)][i_(r-1)]
// :=
// operand[i_0][i_1]...[i_(r-2)][i_(r-1)]
MemRefView vOperand(operand);
MemRefBoundsCapture vOperand(operand);
NGRAPH_CHECK(vOperand.rank() == rank, "Unexpected rank mismatch");
llvm::SmallVector<ValueHandle, 5> indexVars;
@@ -825,9 +820,10 @@ namespace
llvm::SmallVector<ValueHandle, 5> indexVarLbs;
llvm::SmallVector<ValueHandle, 5> indexVarUbs;
llvm::SmallVector<int64_t, 5> indexVarSteps;
auto indexType = IndexType::get(rewriter.getContext());
for (int i = 0; i < rank; i++)
{
indexVars.push_back(IndexHandle());
indexVars.push_back(ValueHandle(indexType));
indexVarPtrs.push_back(&(indexVars.back()));
indexVarLbs.push_back(vOperand.lb(i));
indexVarUbs.push_back(vOperand.ub(i));
@@ -835,15 +831,15 @@ namespace
}
AffineLoopNestBuilder(indexVarPtrs, indexVarLbs, indexVarUbs, indexVarSteps)([&] {
IndexedValue ivRes(result);
IndexedValue ivOperand(operand);
AffineIndexedValue ivRes(result);
AffineIndexedValue ivOperand(operand);
// On the LHS of the assignment, adjust the index for the concatenation axis.
llvm::SmallVector<ValueHandle, 5> resIndexHandles;
for (int i = 0; i < rank; i++)
{
resIndexHandles.push_back(i == concatenationAxis
? indexVars[i] + concatenationAxisPos
? indexVars[i] + ValueHandle(concatenationAxisPos)
: indexVars[i]);
}
@@ -851,11 +847,11 @@ namespace
});
// Move up concatenation_axis_pos for the next operand.
concatenationAxisPos = concatenationAxisPos + vOperand.ub(concatenationAxis);
concatenationAxisPos =
ValueHandle(concatenationAxisPos) + vOperand.ub(concatenationAxis);
}
rewriter.replaceOp(op, {result});
return matchSuccess();
}
@@ -874,14 +870,13 @@ namespace
auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices);
MemRefBoundsCapture vRes(result), vParams(params), vIndices(indices);
// Indexed Values
IndexedValue iRes(result), iIndices(indices);
AffineIndexedValue iRes(result), iIndices(indices);
StdIndexedValue iParams(params);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
SmallVector<IndexHandle, 4> paramsIVs;
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs, paramsIVs;
SmallVector<int64_t, 4> paramsSteps;
SmallVector<ValueHandle*, 4> paramsIVPtrs;
for (auto i = 0; i < vParams.rank(); i++)
@@ -889,8 +884,8 @@ namespace
// skip gather axis
if (i == axis)
continue;
paramsLbs.push_back(IndexHandle(vParams.lb(i)));
paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsLbs.push_back(vParams.lb(i));
paramsUbs.push_back(vParams.ub(i));
paramsSteps.push_back(vParams.step(i));
}
NGRAPH_CHECK(paramsLbs.size() == vParams.rank() - 1 &&
@@ -898,17 +893,17 @@ namespace
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
paramsIVs = makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(paramsIVs));
paramsIVs = ValueHandle::makeIndexHandles(vParams.rank() - 1);
paramsIVPtrs = makeHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps();
auto indicesIVs = makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = makeHandlePointers(MutableArrayRef<IndexHandle>(indicesIVs));
auto indicesIVs = ValueHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = makeHandlePointers(indicesIVs);
SmallVector<IndexHandle, 8> paramsIndices, resIndices;
SmallVector<ValueHandle, 8> paramsIndices, resIndices;
// Make sure we are going to create loops
NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
@@ -946,7 +941,7 @@ namespace
{
if (i == axis)
{
paramsIndices.push_back(IndexHandle(axisIdx));
paramsIndices.push_back(axisIdx);
}
else
{
@@ -1022,10 +1017,10 @@ namespace
NGRAPH_CHECK(groups > 0, "Invalid number of groups");
// create outer group convolution loop
// for group = 0 to groups
IndexHandle iv;
ValueHandle lb = intrinsics::constant_index(0);
ValueHandle ub = intrinsics::constant_index(groups);
auto indexType = IndexType::get(rewriter.getContext());
ValueHandle iv(indexType);
ValueHandle lb = std_constant_index(0);
ValueHandle ub = std_constant_index(groups);
auto imagesType = images.getType().cast<MemRefType>();
auto filtersType = filters.getType().cast<MemRefType>();
@@ -1043,13 +1038,13 @@ namespace
NGRAPH_CHECK(groupsInFilters || filtersShape[0] % groups == 0,
"Filters dim is not divisible by number of groups");
auto channelGroupSize = intrinsics::constant_index(imagesShape[1] / groups);
auto filtersGroupSize = intrinsics::constant_index(
groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);
auto channelGroupSize = std_constant_index(imagesShape[1] / groups);
auto filtersGroupSize =
std_constant_index(groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);
NGRAPH_CHECK(!groupsInFilters || groups == filtersShape[0]);
LoopBuilder::makeAffine(&iv, lb, ub, 1)([&] {
makeAffineLoopBuilder(&iv, lb, ub, 1)([&] {
// lower/upper bounds on image channel dim and kernels dim
auto cLb = iv * channelGroupSize;
auto cUb = cLb + channelGroupSize;
@@ -1152,7 +1147,7 @@ namespace
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl(
"__mlir_callback_2_inputs",
"callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
@@ -1245,7 +1240,7 @@ namespace
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl(
"__mlir_callback_2_inputs",
"callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
@@ -1297,7 +1292,7 @@ namespace
elemTy == biasTy.getElementType(),
"Types mismatch in GemmOp");
MemRefView vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias);
MemRefBoundsCapture vRes(result), vLhs(lhs), vRhs(rhs), vBias(bias);
NGRAPH_CHECK(vLhs.rank() == 2 && vRhs.rank() == 2 && vRes.rank() == 2 && vBias.rank() <= 2,
"Gemm operation is only supported for 2D tensors");
@@ -1361,7 +1356,7 @@ namespace
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
auto callBackFunc = pass.getCallDecl("__mlir_callback_3_inputs",
auto callBackFunc = pass.getCallDecl("callback_3_inputs",
{unrankedMemrefTy,
unrankedMemrefTy,
unrankedMemrefTy,
@@ -1425,7 +1420,7 @@ namespace
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::SOFTMAX), 64);
FuncOp callBackFunc =
pass.getCallDecl("__mlir_callback_1_input",
pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
@@ -1511,11 +1506,12 @@ namespace
auto padBelow = padBelowAttr.getValue();
auto padAbove = padAboveAttr.getValue();
Type elemTy = images.getType().cast<MemRefType>().getElementType();
auto indexType = IndexType::get(rewriter.getContext());
// Create views
MemRefView vRes(result), vImages(images), vFilters(filters);
MemRefBoundsCapture vRes(result), vImages(images), vFilters(filters);
// Create indexed Values
IndexedValue iRes(result), iImages(images), iFilters(filters);
AffineIndexedValue iRes(result), iImages(images), iFilters(filters);
// Bounds on batch size N
ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
// Bounds on spatial dimensions
@@ -1526,9 +1522,8 @@ namespace
unsigned spatialRank = vImages.rank() - 2;
// Result spatial indices and bounds
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);
SmallVector<int64_t, 4> resSteps, filtersSteps;
SmallVector<int, 4> padBelowIntValues;
bool withPadding = false;
@@ -1610,9 +1605,8 @@ namespace
"Results spatial dims mismatches input");
// Filters spatial indices and bounds
auto filtersSpatialIndices = makeIndexHandles(spatialRank);
auto filtersSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(filtersSpatialIndices));
auto filtersSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto filtersSpatialIndicesPtrs = makeHandlePointers(filtersSpatialIndices);
for (auto i = 0; i < spatialRank; i++)
{
@@ -1658,23 +1652,22 @@ namespace
// Initialize output to zero
{
IndexHandle n, k, c;
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
ValueHandle n(indexType), k(indexType), c(indexType);
auto resSpatialIndices = ValueHandle::makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = makeHandlePointers(resSpatialIndices);
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
SmallVector<IndexHandle, 4> resIndices;
SmallVector<ValueHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
if (groupConvolution && groupsInFilters)
{
// compute global C_OUT from gID and k
// gId * C_OUT (num of filters) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
}
else
{
@@ -1689,31 +1682,31 @@ namespace
});
}
IndexHandle n, k, c;
ValueHandle n(indexType), k(indexType), c(indexType);
// Convolution loop
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
makeAffineLoopBuilder(&n, batchLb, batchUb, 1)([&] {
// Number of filters loop
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
makeAffineLoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
// Channels loop
LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] {
makeAffineLoopBuilder(&c, numChannelsLb, numChannelsUb, 1)([&] {
// Results loop
AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
// Compute image start indices
SmallVector<IndexHandle, 4> imgStartIndices;
SmallVector<ValueHandle, 4> imgStartIndices;
for (auto i = 0; i < spatialRank; i++)
{
IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
auto stride = intrinsics::constant_index(iAttr.getInt());
imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride));
auto stride = std_constant_index(iAttr.getInt());
imgStartIndices.push_back(resSpatialIndices[i] * stride);
}
SmallVector<IndexHandle, 4> resIndices;
SmallVector<ValueHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
if (groupConvolution && groupsInFilters)
{
// gId * C_OUT (num of filters) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
resIndices.push_back(ValueHandle(gId) * numFiltersUb + k);
}
else
{
@@ -1727,15 +1720,14 @@ namespace
filtersSpatialLbs,
filtersSpatialUbs,
filtersSteps)([&] {
SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
SmallVector<ValueHandle, 4> imgIndices, filtersIndices;
// Image indices
// Here we compute the virtual start index into the padded image.
imgIndices.push_back(n);
imgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++)
{
imgIndices.push_back(
IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
imgIndices.push_back(imgStartIndices[i] + filtersSpatialIndices[i]);
}
// Filter indices
@@ -1744,14 +1736,14 @@ namespace
// index
if (groupConvolution && groupsInFilters)
{
filtersIndices.push_back(IndexHandle(gId));
filtersIndices.push_back(ValueHandle(gId));
}
filtersIndices.push_back(k);
// subtract lower bound of channel
// if we are doing group convolution this bound will advance based
// on the group id. For the filters, it should always start from 0
filtersIndices.push_back(IndexHandle(c - numChannelsLb));
filtersIndices.push_back(c - numChannelsLb);
filtersIndices.insert(filtersIndices.end(),
filtersSpatialIndices.begin(),
filtersSpatialIndices.end());
@@ -1759,7 +1751,7 @@ namespace
if (withPadding)
{
// if args : img dims, img lbs, img ubs
SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
SmallVector<ValueHandle, 4>::iterator it = imgIndices.begin();
std::advance(it, 2);
SmallVector<Value, 4> affineIfArgs(it, imgIndices.end());
affineIfArgs.insert(
@@ -1777,14 +1769,14 @@ namespace
ScopedContext scope(rewriter, loc);
// We must subtract pad below before img load, since the
// physical image is not padded
SmallVector<IndexHandle, 4> adjustedImgIndices;
SmallVector<ValueHandle, 4> adjustedImgIndices;
adjustedImgIndices.push_back(n);
adjustedImgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++)
{
adjustedImgIndices.push_back(IndexHandle(
adjustedImgIndices.push_back(
imgIndices[2 + i] -
intrinsics::constant_index(padBelowIntValues[i])));
std_constant_index(padBelowIntValues[i]));
}
iRes(resIndices) =
iRes(resIndices) +
@@ -1821,15 +1813,15 @@ namespace
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs);
MemRefBoundsCapture vRes(result), vLHS(lhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs);
AffineIndexedValue iRes(result), iLHS(lhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
@@ -1867,15 +1859,15 @@ namespace
ScopedContext scope(rewriter, loc);
// Views
MemRefView vRes(result), vLHS(lhs), vRHS(rhs);
MemRefBoundsCapture vRes(result), vLHS(lhs), vRHS(rhs);
// Index Values
IndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
AffineIndexedValue iRes(result), iLHS(lhs), iRHS(rhs);
// Bounds Index Handles
auto lbs = vLHS.getLbs();
auto ubs = vLHS.getUbs();
// Loop induction vars
auto ivs = makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
auto ivs = ValueHandle::makeIndexHandles(vLHS.rank());
auto pivs = makeHandlePointers(ivs);
// Steps
auto steps = vLHS.getSteps();
// element type of the operand
......@@ -1900,65 +1892,57 @@ namespace
iRes(ivs) = iLHS(ivs) / iRHS(ivs);
}
// TODO(pthoreho) For all comparison operators, use
// edsc::intrinsics::zero_extendi(ValueHandle(iLHS(ivs)) !=
// zero_extendi(ValueHandle(iLHS(ivs)) !=
// ValueHandle(iRHS(ivs)), IntegerType::get(8, op->getContext()));
// instead of edsc::intrinsics::select once `zero_extendi` is
// instead of std_select once `zero_extendi` is
// made available in the edsc::intrinsics namespace in the MLIR repo.
else if (isa<NGGreaterOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGLessOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGGreaterEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) >= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGLessEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) <= ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) == ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGNotEqOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) != ValueHandle(iRHS(ivs)),
createOneConstant(elemTy),
createZeroConstant(elemTy));
}
else if (isa<NGMaxOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) > ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else if (isa<NGMinOp>(op))
{
iRes(ivs) =
edsc::intrinsics::select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
iRes(ivs) = std_select(ValueHandle(iLHS(ivs)) < ValueHandle(iRHS(ivs)),
ValueHandle(iLHS(ivs)),
ValueHandle(iRHS(ivs)));
}
else
{
@@ -1995,10 +1979,10 @@ namespace
Value result = pass.buildOutputDefs(op, rewriter)[0];
// Views
MemRefView vRes(result), vArg(arg);
MemRefBoundsCapture vRes(result), vArg(arg);
// Index Values
StdIndexedValue iRes(result), stdArg(arg);
IndexedValue affineArg(arg);
AffineIndexedValue affineArg(arg);
// Bounds Index Handles
auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs();
@@ -2008,8 +1992,8 @@ namespace
Type resTy = result.getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{
auto ivs = makeIndexHandles(vRes.rank());
auto pivs = makeHandlePointers(MutableArrayRef<IndexHandle>(ivs));
auto ivs = ValueHandle::makeIndexHandles(vRes.rank());
auto pivs = makeHandlePointers(ivs);
auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis);
AffineLoopNestBuilder(pivs, resLbs, resUbs, steps)(
@@ -2018,10 +2002,10 @@ namespace
// Generate loop nest that computes the actual index reduction.
{
auto allIVs = makeIndexHandles(vArg.rank());
auto pAllIVs = makeHandlePointers(MutableArrayRef<IndexHandle>(allIVs));
auto allIVs = ValueHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = makeHandlePointers(allIVs);
auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs;
SmallVector<ValueHandle, 8> nonRedIVs;
Type resTy = result.getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(),
@@ -2049,10 +2033,8 @@ namespace
// Select the min/max value and cast it back to integer type before storing it.
ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select(
affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select(
stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
? std_select(affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
: std_select(stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
});
@@ -2123,7 +2105,7 @@ namespace
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc =
pass.getCallDecl("__mlir_callback_1_input",
pass.getCallDecl("callback_1_input",
{unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
@@ -2168,11 +2150,11 @@ namespace
{
if (floatTy.isF32())
{
return intrinsics::constant_float(llvm::APFloat(0.0f), floatTy);
return std_constant_float(llvm::APFloat(0.0f), floatTy);
}
else if (floatTy.isF64())
{
return intrinsics::constant_float(llvm::APFloat(0.0), floatTy);
return std_constant_float(llvm::APFloat(0.0), floatTy);
}
else
{
@@ -2181,7 +2163,7 @@ namespace
}
else if (auto intTy = type.dyn_cast<IntegerType>())
{
return intrinsics::constant_int(0, intTy.getWidth());
return std_constant_int(0, intTy.getWidth());
}
NGRAPH_UNREACHABLE("Unsupported type");
}
@@ -2192,11 +2174,11 @@ namespace
{
if (floatTy.isF32())
{
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
return std_constant_float(llvm::APFloat(1.0f), floatTy);
}
else if (floatTy.isF64())
{
return intrinsics::constant_float(llvm::APFloat(1.0f), floatTy);
return std_constant_float(llvm::APFloat(1.0), floatTy);
}
else
{
@@ -2205,7 +2187,7 @@ namespace
}
else if (auto intTy = type.dyn_cast<IntegerType>())
{
return intrinsics::constant_int(1, intTy.getWidth());
return std_constant_int(1, intTy.getWidth());
}
NGRAPH_UNREACHABLE("Unsupported type");
}
...
@@ -23,9 +23,7 @@
#include "contrib/mlir/core/ngraph_dialect/type.hpp"
#include <llvm/IR/Module.h>
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/Dialect/AffineOps/EDSC/Builders.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
...
@@ -719,7 +719,7 @@ static void __mlir_cblas_sgemm_with_bias(StaticMemRef* memRefmatA,
}
}
extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index, OpType type)
extern "C" void _mlir_ciface_callback_1_input(void* input, void* output, size_t index, OpType type)
{
auto unrankedMemRefInput = reinterpret_cast<UnrankedMemRef*>(input);
auto unrankedMemRefOutput = reinterpret_cast<UnrankedMemRef*>(output);
@@ -752,8 +752,8 @@ extern "C" void __mlir_callback_1_input(void* input, void* output, size_t index,
}
}
extern "C" void
__mlir_callback_2_inputs(void* input0, void* input1, void* output, size_t index, OpType type)
extern "C" void _mlir_ciface_callback_2_inputs(
void* input0, void* input1, void* output, size_t index, OpType type)
{
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
auto unrankedMemRefInput1 = reinterpret_cast<UnrankedMemRef*>(input1);
@@ -780,7 +780,7 @@ extern "C" void
}
}
extern "C" void __mlir_callback_3_inputs(
extern "C" void _mlir_ciface_callback_3_inputs(
void* input0, void* input1, void* input2, void* output, size_t index, OpType type)
{
auto unrankedMemRefInput0 = reinterpret_cast<UnrankedMemRef*>(input0);
...
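How the callback naming lines up end to end under emitCWrappers=true (a
sketch, assuming the wrapper behavior of this MLIR revision):

    // The lowering declares and calls the plain symbol:
    //     call @callback_2_inputs(...)
    // The LLVM lowering routes the call through the generated C-interface
    // wrapper symbol:
    //     @_mlir_ciface_callback_2_inputs
    // which the runtime satisfies with the extern "C" definitions above.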
@@ -83,7 +83,7 @@ void MLIRCPURuntime::bindArguments(const std::vector<MemRefArg>& args)
{
NGRAPH_CHECK(m_module, "MLIR module is not ready.");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("main");
auto func = m_module->lookupSymbol<mlir::LLVM::LLVMFuncOp>("_mlir_ciface_main");
NGRAPH_CHECK(func && !func.getBlocks().empty(), "Function not found");
// Set external arguments
@@ -127,14 +127,15 @@ void MLIRCPURuntime::execute()
// uniformity reasons, it takes a list of type-erased pointers to arguments.
// Please note that the 'invoke' method is overloaded with a parameter pack version.
// Make sure the MutableArrayRef version is invoked.
auto invocationResult = m_engine->invoke("main", llvm::MutableArrayRef<void*>(m_invokeArgs));
auto invocationResult =
m_engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(m_invokeArgs));
if (clDumpObjectFile)
{
m_engine->dumpToObjectFile(clObjectFilename.empty() ? "jitted_mlir.o"
: clObjectFilename.getValue());
}
NGRAPH_CHECK(!invocationResult, "JIT invocation of 'main' failed\n");
NGRAPH_CHECK(!invocationResult, "JIT invocation of '_mlir_ciface_main' failed\n");
}
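A minimal sketch of such a wrapper invocation in isolation (`engine` is an
initialized mlir::ExecutionEngine; the descriptor pointers are hypothetical
and would be bound by bindArguments in the real code):

    // Each entry is a type-erased pointer to one argument; for a memref
    // parameter the argument is a StaticMemRef*, so the entry points at
    // that pointer.
    StaticMemRef* in = nullptr;  // filled in by the argument-binding step
    StaticMemRef* out = nullptr;
    llvm::SmallVector<void*, 4> args{&in, &out};
    auto err = engine->invoke("_mlir_ciface_main", llvm::MutableArrayRef<void*>(args));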
void MLIRCPURuntime::cleanup()
...
@@ -16,7 +16,7 @@
set(LIBS
mlir_backend
MLIROptMain
MLIROptLib
MLIRPass
MLIRParser
LLVMSupport
...
@@ -21,10 +21,21 @@
#include "contrib/mlir/core/ngraph_dialect/dialect.hpp"
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
#include <mlir/Dialect/AffineOps/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/Dialect/LoopOps/LoopOps.h>
#include <mlir/Dialect/StandardOps/Ops.h>
#include <mlir/Dialect/VectorOps/VectorOps.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/LocationSnapshot.h>
#include <mlir/Transforms/Passes.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/Debug.h>
using namespace mlir;
static llvm::cl::opt<bool> clPrintIRAfterAll(
"ngraph-print-ir-after-all",
@@ -35,15 +46,47 @@ static llvm::cl::opt<bool> clPrintIRAfterAll(
void ngraph::runtime::ngmlir::initializeNGraphMLIR()
{
// Initialize a dialect only once.
// We currently have no way to query if a dialect is previously
// registered. So using a global flag instead.
static bool init = false;
if (!init)
{
mlir::registerDialect<mlir::NGraphOpsDialect>();
init = true;
}
// Initialize MLIR dialects and passes only once.
static bool init_once = []() {
// In-tree Dialects.
registerDialect<AffineOpsDialect>();
registerDialect<LLVM::LLVMDialect>();
registerDialect<loop::LoopOpsDialect>();
registerDialect<StandardOpsDialect>();
registerDialect<vector::VectorOpsDialect>();
// nGraph dialects.
registerDialect<mlir::NGraphOpsDialect>();
// In-tree passes.
// No-op to avoid DCE on the following pass initializations.
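// The comparison below is always true at runtime but opaque to the
// optimizer, so the create*Pass() calls remain referenced (keeping their
// static pass registrations linked in) without ever executing.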
if (std::getenv("bar") != (char*)-1)
return false;
createCanonicalizerPass();
createCSEPass();
createVectorizePass({});
createLoopUnrollPass();
createLoopUnrollAndJamPass();
createSimplifyAffineStructuresPass();
createLoopFusionPass();
createLoopInvariantCodeMotionPass();
createAffineLoopInvariantCodeMotionPass();
createPipelineDataTransferPass();
createLowerAffinePass();
createLoopTilingPass(0);
createLoopCoalescingPass();
createAffineDataCopyGenerationPass(0, 0);
createMemRefDataFlowOptPass();
createStripDebugInfoPass();
createPrintOpStatsPass();
createInlinerPass();
createSymbolDCEPass();
createLocationSnapshotPass({});
return true;
}();
(void)init_once;
}
void ngraph::runtime::ngmlir::dumpMlirModule(const std::string msg, mlir::ModuleOp module)
...
@@ -10,7 +10,7 @@
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: %0 = memref_cast %arg0 : memref<2x3xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg2 : memref<2x3xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -> !ng.tensor<2x3xf32> {
%0 = "ng.softmax"(%arg0) {axes = [0]} : (!ng.tensor<2x3xf32>) -> !ng.tensor<2x3xf32>
"ng.return"(%0) : (!ng.tensor<2x3xf32>) -> ()
@@ -26,7 +26,7 @@ func @simple_softmax(%arg0: !ng.tensor<2x3xf32>, %arg1: !ng.tensor<1x!ng.i64>) -
// CHECK: %1 = memref_cast %arg1 : memref<6x4xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<3x4xf32> to memref<*xf32>
// CHECK: %3 = memref_cast %arg3 : memref<3x4xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_3_inputs(%0, %1, %2, %3, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2: !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32> {
%0 = "ng.gemm"(%arg0, %arg1, %arg2) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, transA = false, transB = false} : (!ng.tensor<3x6xf32>, !ng.tensor<6x4xf32>, !ng.tensor<3x4xf32>) -> !ng.tensor<3x4xf32>
"ng.return"(%0) : (!ng.tensor<3x4xf32>) -> ()
@@ -41,7 +41,7 @@ func @simple_gemm(%arg0: !ng.tensor<3x6xf32>, %arg1: !ng.tensor<6x4xf32>, %arg2:
// CHECK: %0 = memref_cast %arg0 : memref<3x2xf32> to memref<*xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x3xf32> to memref<*xf32>
// CHECK: %2 = memref_cast %arg2 : memref<2x2xf32> to memref<*xf32>
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32> {
%0 = "ng.matmul"(%arg0, %arg1) {transposeA = true, transposeB = true} : (!ng.tensor<3x2xf32>, !ng.tensor<2x3xf32>) -> !ng.tensor<2x2xf32>
"ng.return"(%0) : (!ng.tensor<2x2xf32>) -> ()
@@ -55,7 +55,7 @@ func @simple_matmul(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<2x3xf32>) -> !
// CHECK: %1 = memref_cast %arg1 : memref<2x1x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32> {
%0 = "ng.avgPool"(%arg0) {includePadding = true, padAbove = [1, 1], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x1x3x3xf32>) -> ()
@@ -69,7 +69,7 @@ func @simple_avgpool(%arg0: !ng.tensor<2x1x3x3xf32>) -> !ng.tensor<2x1x3x3xf32>
// CHECK: %1 = memref_cast %arg1 : memref<2x2x3x3xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32> {
%0 = "ng.avgPoolBackprop"(%arg0) {forwardArgShape = [2, 2, 3, 3], includePadding = false, padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 2]} : (!ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3x3xf32>
"ng.return"(%0) : (!ng.tensor<2x2x3x3xf32>) -> ()
@@ -83,7 +83,7 @@ func @simple_avgpoolbackprop(%arg0: !ng.tensor<2x2x2x2xf32>) -> !ng.tensor<2x2x3
// CHECK: %1 = memref_cast %arg1 : memref<64x3x9x6x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_1_input(%0, %1, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32> {
%0 = "ng.maxPool"(%arg0) {padAbove = [6, 4, 5], padBelow = [5, 6, 4], windowMovementStrides = [2, 3, 4], windowShape = [2, 3, 2]} : (!ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x5xf32>
"ng.return"(%0) : (!ng.tensor<64x3x9x6x5xf32>) -> ()
@@ -98,7 +98,7 @@ func @simple_maxpool(%arg0: !ng.tensor<64x3x7x8x10xf32>) -> !ng.tensor<64x3x9x6x
// CHECK: %2 = memref_cast %arg2 : memref<2x2x5x5xf32> to memref<*xf32>
// CHECK: %[[C1:.*]] = constant 0 : i64
// CHECK: %[[C2:.*]] = constant {{[0-9]+}} : i64
// CHECK: call @__mlir_callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
// CHECK: call @callback_2_inputs(%0, %1, %2, %[[C1]], %[[C2]]) : (memref<*xf32>, memref<*xf32>, memref<*xf32>, i64, i64) -> ()
func @simple_maxpoolbackprop(%arg0: !ng.tensor<2x2x5x5xf32>, %arg1: !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32> {
%0 = "ng.maxPoolBackprop"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], windowMovementStrides = [1, 1], windowShape = [2, 3]} : (!ng.tensor<2x2x5x5xf32>, !ng.tensor<2x2x4x3xf32>) -> !ng.tensor<2x2x5x5xf32>
"ng.return"(%0) : (!ng.tensor<2x2x5x5xf32>) -> ()
...