Commit 8e46ff86 authored by Nagy Mostafa, committed by Sang Ik Lee

[MLIR] Enable lowering of GroupConv in MLIR CPU backend (#4102)

* WIP

* WIP

* Refactored existing convolution

* Add Channel and num of filters bounds parameters to helper

* Works on unit-tests. v1 op gets converted and breaks

* Fixed group conv with groups in filters shape. Tests pass

* style

* add LIT tests

* Switch outer loop to affine loop

* re-org code

* PR fixes

* Revert ops.td

* PR fixes
parent c2bdccfa
@@ -169,6 +169,36 @@ namespace
PatternRewriter& rewriter,
DialectLoweringPass& pass);
// Generates a convolution kernel that can be used for single or
// group convolution. It can handle filters where the C_OUT dim includes
// all groups, or where groups form an additional dimension before C_OUT.
//
// For single convolution, the default parameters do not
// have to be specified; they are deduced from the input shapes.
//
// For group convolution, the caller has to generate the outer loop
// over the number of groups and compute the bounds on the
// C_IN and C_OUT dimensions, passing the bounds and the IV of the
// outer loop as follows:
//
// cLb/Ub : Values representing bounds on channel dim in image (C_IN)
// kLb/Ub : Values representing bounds on numFilters dim in filters (C_OUT)
// gId : Value representing induction variable for the outer loop
void lowerConvolution(Value* result,
Value* images,
Value* filters,
ArrayAttr stridesAttr,
ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr,
PatternRewriter& rewriter,
DialectLoweringPass& pass,
Location loc,
Value* cLb = nullptr,
Value* cUb = nullptr,
Value* kLb = nullptr,
Value* kUb = nullptr,
Value* gId = nullptr);
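// A minimal caller sketch for the group case (names here are
// illustrative; the real loop is built in the NGGroupConvOp rewriter
// below):
//
//   IndexHandle gId;
//   LoopBuilder::makeAffine(&gId, zeroIndex, numGroups, 1)([&] {
//       auto cLb = gId * channelsPerGroup, cUb = cLb + channelsPerGroup;
//       auto kLb = gId * filtersPerGroup, kUb = kLb + filtersPerGroup;
//       lowerConvolution(result, images, filters, strides, padBelow,
//                        padAbove, rewriter, pass, loc,
//                        cLb, cUb, kLb, kUb, gId);
//   });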
template <typename OP>
void lowerPooling(Operation* op,
ArrayRef<Value*> operands,
@@ -955,408 +985,215 @@ namespace
REWRITER(NGConvolutionOp)
{
auto convolOp = cast<NGConvolutionOp>(op);
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0];
Value* filters = operands[1];
auto strides = convolOp.strides();
auto padBelow = convolOp.padBelow();
auto padAbove = convolOp.padAbove();
lowerConvolution(result,
images,
filters,
strides,
padBelow,
padAbove,
rewriter,
pass,
convolOp.getLoc());
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGGroupConvOp)
{
auto gConvOp = cast<NGGroupConvOp>(op);
ScopedContext scope(rewriter, gConvOp.getLoc());
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0];
Value* filters = operands[1];
auto strides = gConvOp.strides();
auto padBelow = gConvOp.padBelow();
auto padAbove = gConvOp.padAbove();
int groups = gConvOp.groups().getSExtValue();
NGRAPH_CHECK(groups > 0, "Invalid number of groups");
// create outer group convolution loop
// for group = 0 to groups
IndexHandle iv;
ValueHandle lb = intrinsics::constant_index(0);
ValueHandle ub = intrinsics::constant_index(groups);
auto imagesType = images->getType().cast<MemRefType>();
auto filtersType = filters->getType().cast<MemRefType>();
auto imagesShape = imagesType.getShape();
auto filtersShape = filtersType.getShape();
// Does the filters shape include the number of groups as an extra dim?
bool groupsInFilters = (filtersShape.size() != imagesShape.size());
NGRAPH_CHECK(imagesType.hasStaticShape() && filtersType.hasStaticShape(),
"Dynamic shapes are not supported");
NGRAPH_CHECK(imagesShape[1] % groups == 0,
"Channel dim is not divisible by number of groups");
NGRAPH_CHECK(groupsInFilters || filtersShape[0] % groups == 0,
"Filters dim is not divisible by number of groups");
auto channelGroupSize = intrinsics::constant_index(imagesShape[1] / groups);
auto filtersGroupSize = intrinsics::constant_index(
groupsInFilters ? filtersShape[1] : filtersShape[0] / groups);
NGRAPH_CHECK(!groupsInFilters || groups == filtersShape[0]);
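// Example (the shapes used by the groupConv LIT test later in this diff):
// images 1x4x2x2, filters 2x2x1x1, groups = 2. Ranks match, so
// groupsInFilters is false, channelGroupSize = 4/2 = 2 and
// filtersGroupSize = 2/2 = 1; group g then covers input channels
// [2g, 2g+2) and filters [g, g+1).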
LoopBuilder::makeAffine(&iv, lb, ub, 1)([&] {
// lower/upper bounds on image channel dim and kernels dim
auto cLb = iv * channelGroupSize;
auto cUb = cLb + channelGroupSize;
auto kLb = iv * filtersGroupSize;
auto kUb = kLb + filtersGroupSize;
lowerConvolution(result,
images,
filters,
strides,
padBelow,
padAbove,
rewriter,
pass,
gConvOp.getLoc(),
cLb,
cUb,
kLb,
kUb,
iv);
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGReturnOp)
{
pass.insertDeallocs(rewriter);
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
}
// Use callback: Pooling, MatMul, Gemm, Softmax
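// The operands are cast to unranked memrefs so that one callback signature
// can accept tensors of any rank; the callee recovers rank and shape from
// the unranked memref descriptor.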
static void castMemRef(SmallVector<mlir::Value*, 4> inputs,
SmallVector<mlir::Value*, 4>& outputs,
PatternRewriter& rewriter,
UnrankedMemRefType type)
{
for (auto in : inputs)
{
auto out = rewriter.create<mlir::MemRefCastOp>(rewriter.getUnknownLoc(), in, type);
outputs.push_back(out);
}
}
REWRITER(NGAvgPoolOp)
{
lowerPooling<mlir::NGAvgPoolOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGAvgPoolBackpropOp)
{
lowerPooling<mlir::NGAvgPoolBackpropOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxPoolOp)
{
lowerPooling<mlir::NGMaxPoolOp>(op, operands, rewriter, pass);
return matchSuccess();
}
REWRITER(NGMaxPoolBackpropOp)
{
auto pooling = cast<NGMaxPoolBackpropOp>(op);
auto loc = pooling.getLoc();
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* src = operands[0];
Value* delta = operands[1];
ArrayRef<Attribute> windowShape = pooling.windowShape().getValue();
ArrayRef<Attribute> windowStrides = pooling.windowMovementStrides().getValue();
ArrayRef<Attribute> padBelow = pooling.padBelow().getValue();
ArrayRef<Attribute> padAbove = pooling.padAbove().getValue();
Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(src && delta && result, "Unexpected null values in MaxPoolBackprop Op");
auto resultTy = result->getType().dyn_cast<MemRefType>();
auto resultShape = resultTy.getShape();
auto srcTy = src->getType().dyn_cast<MemRefType>();
auto srcShape = srcTy.getShape();
auto deltaTy = delta->getType().dyn_cast<MemRefType>();
auto deltaShape = deltaTy.getShape();
NGRAPH_CHECK(resultTy, "Unexpected non-memref result type");
NGRAPH_CHECK(srcTy, "Unexpected non-memref src type");
NGRAPH_CHECK(deltaTy, "Unexpected non-memref delta type");
Type elemTy = resultTy.getElementType();
NGRAPH_CHECK(elemTy == srcTy.getElementType() && elemTy == deltaTy.getElementType(),
"Types mismatch in MaxPoolBackprop");
NGRAPH_CHECK((srcShape.size() == 4 && resultShape.size() == 4) ||
(srcShape.size() == 5 && resultShape.size() == 5),
"MKLDNN pooling operation is only supported for 3D and 5D tensors");
auto int64Ty = rewriter.getIntegerType(64);
auto unrankedMemrefTy = UnrankedMemRefType::get(elemTy, 0);
SmallVector<mlir::Value*, 4> inputs = {src, delta, result};
SmallVector<mlir::Value*, 4> outputs;
castMemRef(inputs, outputs, rewriter, unrankedMemrefTy);
FuncOp callBackFunc = pass.getCallDecl(
"__mlir_callback_2_inputs",
{unrankedMemrefTy, unrankedMemrefTy, unrankedMemrefTy, int64Ty, int64Ty},
{},
rewriter);
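// The callback receives the two unranked inputs, the unranked output, an
// index into the attributes vector (filled via insertAttrs below), and an
// OpType code identifying which operation to run.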
opAttrs attrs;
if (srcShape.size() == 4)
{
attrs.poolAttrs2d.includePaddingInAvgComputation = false;
for (auto i = 0; i < 2; i++)
{
attrs.poolAttrs2d.windowShape[i] = windowShape[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs2d.windowStrides[i] = windowStrides[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs2d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs2d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
}
else if (srcShape.size() == 5)
{
attrs.poolAttrs3d.includePaddingInAvgComputation = false;
for (auto i = 0; i < 3; i++)
{
attrs.poolAttrs3d.windowShape[i] = windowShape[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs3d.windowStrides[i] = windowStrides[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs3d.padBelow[i] = padBelow[i].cast<IntegerAttr>().getInt();
attrs.poolAttrs3d.padAbove[i] = padAbove[i].cast<IntegerAttr>().getInt();
}
}
auto index = pass.insertAttrs(attrs);
auto attrsIndexArg =
rewriter.create<mlir::ConstantIntOp>(rewriter.getUnknownLoc(), index, 64);
auto opTypeArg = rewriter.create<mlir::ConstantIntOp>(
rewriter.getUnknownLoc(), static_cast<int64_t>(OpType::MAXPOOLBACKPROP), 64);
SmallVector<mlir::Value*, 4> args = {
outputs[0], outputs[1], outputs[2], attrsIndexArg, opTypeArg};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, result);
@@ -1612,6 +1449,366 @@ namespace
#undef REWRITER
/// End of pattern matchers
void lowerConvolution(Value* result,
Value* images,
Value* filters,
ArrayAttr stridesAttr,
ArrayAttr padBelowAttr,
ArrayAttr padAboveAttr,
PatternRewriter& rewriter,
DialectLoweringPass& pass,
Location loc,
Value* cLb,
Value* cUb,
Value* kLb,
Value* kUb,
Value* gId)
{
// Let Images shape be [N, C_IN, D_1, ... D_f]
// Let Filters shape be [C_OUT, C_IN, F_1, ... F_f]
// (or [GROUPS, C_OUT, C_IN, F_1, ... F_f] in case of
// group convolution with groups in filters shape)
// Output shape will be [N, C_OUT, R_1, ..R_f]
// where R_i = (AdjD_i - AdjF_i + 1) / Strides[i]
//
// AdjD_i is adjusted image spatial dimension after padding and dilation
// AdjD_i = padBelow[i] + (dilation[i] * (D_i - 1) + 1) + padAbove[i]
//
// AdjF_i is adjusted filters spatial dimension after dilation
// AdjF_i = dilation[i] * (F_i - 1) + 1
//
// If no padding, padAbove/Below[i] = 0
// If no dilation, dilation[i] is 1
//
// Generate the following (currently without padding/dilation support)
//
//
// for n : 0 -> N
// for k : 0 -> C_OUT
// for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f>
// //initialize result to zero
// Output[n, k, r_1, .. r_f] = 0;
//
// for n : 0 -> N
// for k : 0 -> C_OUT
// for c : 0 -> C_IN
// // iterate over output spatial shape
// for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f> //
// //compute image start inputs indices
// i_1 = r_1 * strides[0];
// ..
// i_f = r_f * strides[f - 1];
// // iterate over kernel spatial shape
// for <j_1 .. j_f> : <0 .. 0> -> <F_1 .. F_f>
// Output[n, k, r_1, .. r_f] +=
// Images[n, c, i_1 + j_1, .. i_f + j_f] * Filters[k, c, j_1, .. j_f]
// With padding, we check (using IntegerSets) whether each spatial index into Images lies
// inside the non-padded spatial region. If true, we perform the computation:
//
// for <j_1 .. j_f> : <0 .. 0> -> <F_1 .. F_f>
// if(indices in non-padded region):
// Output[n, k, r_1, .. r_f] +=
// Images[n, c, i_1 + j_1, .. i_f + j_f] * Filters[k, c, j_1, .. j_f]
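// Worked example (no padding or dilation): Images [1, 2, 4, 4],
// Filters [2, 2, 3, 3], Strides [1, 1] give AdjD_i = 4 and AdjF_i = 3,
// so R_i = (4 - 3 + 1) / 1 = 2 and the output shape is [1, 2, 2, 2].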
ScopedContext scope(rewriter, loc);
auto strides = stridesAttr.getValue();
auto padBelow = padBelowAttr.getValue();
auto padAbove = padAboveAttr.getValue();
Type elemTy = images->getType().cast<MemRefType>().getElementType();
// Create views
MemRefView vRes(result), vImages(images), vFilters(filters);
// Create indexed Values
IndexedValue iRes(result), iImages(images), iFilters(filters);
// Bounds on batch size N
ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
// Bounds on spatial dimensions
SmallVector<ValueHandle, 4> resSpatialLbs, resSpatialUbs;
SmallVector<ValueHandle, 4> imgSpatialLbs, imgSpatialUbs;
SmallVector<ValueHandle, 4> filtersSpatialLbs, filtersSpatialUbs;
// Spatial rank
unsigned spatialRank = vImages.rank() - 2;
// Result spatial indices and bounds
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
SmallVector<int64_t, 4> resSteps, filtersSteps;
SmallVector<int, 4> padBelowIntValues;
bool withPadding = false;
// Do we have an extra dim for groups, or is it folded into numFilters?
bool groupsInFilters = (vImages.rank() != vFilters.rank());
bool groupConvolution = (kLb != nullptr);
// Number of groups can be in filters shape only with group convolution
NGRAPH_CHECK(!groupsInFilters ||
(kLb != nullptr && kUb != nullptr && cLb != nullptr && cUb != nullptr));
// Bounds on number of filters
ValueHandle numFiltersLb(rewriter.getIndexType());
ValueHandle numFiltersUb(rewriter.getIndexType());
if (groupConvolution)
{
if (groupsInFilters)
{
// groups have a separate dim, so the num filters dim is already per-group; use it whole
numFiltersLb = vFilters.lb(1);
numFiltersUb = vFilters.ub(1);
}
else
{
// use split dim within bounds generated in outer loop
numFiltersLb = ValueHandle(kLb);
numFiltersUb = ValueHandle(kUb);
}
}
else
{
numFiltersLb = vFilters.lb(0);
numFiltersUb = vFilters.ub(0);
}
// determine where spatial index starts in filters
int filtersSpatialIdx = 2;
const int imgSpatialIdx = 2;
if (groupConvolution && groupsInFilters)
{
filtersSpatialIdx = 3;
}
// Bounds on number of channels
ValueHandle numChannelsLb = (cLb == nullptr) ? vImages.lb(1) : ValueHandle(cLb);
ValueHandle numChannelsUb = (cUb == nullptr) ? vImages.ub(1) : ValueHandle(cUb);
for (auto i = 0; i < spatialRank; i++)
{
// result spatial bounds and steps
resSpatialLbs.push_back(vRes.lb(imgSpatialIdx + i));
resSpatialUbs.push_back(vRes.ub(imgSpatialIdx + i));
resSteps.push_back(vRes.step(imgSpatialIdx + i));
// image spatial bounds
imgSpatialLbs.push_back(vImages.lb(imgSpatialIdx + i));
imgSpatialUbs.push_back(vImages.ub(imgSpatialIdx + i));
// Check if we have any padding and collect pad values
IntegerAttr iAttr = padBelow[i].cast<IntegerAttr>();
int padValue = iAttr.getInt();
if (padValue)
{
withPadding = true;
}
padBelowIntValues.push_back(padValue);
iAttr = padAbove[i].cast<IntegerAttr>();
padValue = iAttr.getInt();
if (padValue)
{
withPadding = true;
}
}
NGRAPH_CHECK((groupConvolution && groupsInFilters) || (vImages.rank() == vFilters.rank()),
"Images and Filters have unequal ranks");
NGRAPH_CHECK(resSpatialLbs.size() == resSpatialUbs.size() &&
resSpatialLbs.size() == spatialRank,
"Results spatial dims mismatches input");
// Filters spatial indices and bounds
auto filtersSpatialIndices = makeIndexHandles(spatialRank);
auto filtersSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(filtersSpatialIndices));
for (auto i = 0; i < spatialRank; i++)
{
filtersSpatialLbs.push_back(vFilters.lb(filtersSpatialIdx + i));
filtersSpatialUbs.push_back(vFilters.ub(filtersSpatialIdx + i));
filtersSteps.push_back(vFilters.step(filtersSpatialIdx + i));
}
IntegerSet nonPaddedRange;
if (withPadding)
{
// Create affine expressions and IntegerSet
// IntegerSet (d0, d1, .. d_N-1)[LB_0, LB_1, .. LB_N-1, UB_0, UB_1, .. UB_N-1], where
// for each dim:
// (d_dim - padBelow[dim] - LB_dim >= 0),
// (padBelow[dim] + UB_dim - d_dim - 1 >= 0)
SmallVector<AffineExpr, 4> affineExprs;
// Bool to indicate if expr is equality or inequality
SmallVector<bool, 4> isEq;
for (unsigned dim = 0; dim < spatialRank; dim++)
{
// i_dim
auto dimExpr = rewriter.getAffineDimExpr(dim);
auto imgLbExpr = rewriter.getAffineSymbolExpr(dim);
// expr1 : i_dim - padBelow[dim] - imgLB >= 0
auto padBelowExpr = rewriter.getAffineConstantExpr(padBelowIntValues[dim]);
affineExprs.push_back(dimExpr - padBelowExpr - imgLbExpr);
isEq.push_back(false);
// expr2: padBelow[dim] + imgUB - i_dim - 1 >= 0
auto imgUbExpr = rewriter.getAffineSymbolExpr(spatialRank + dim);
auto oneExpr = rewriter.getAffineConstantExpr(1);
affineExprs.push_back(padBelowExpr + imgUbExpr - dimExpr - oneExpr);
isEq.push_back(false);
}
NGRAPH_CHECK(affineExprs.size() == isEq.size() && isEq.size() == 2 * spatialRank,
"Invalid number of expressions in the IntegerSet");
nonPaddedRange = IntegerSet::get(spatialRank, 2 * spatialRank, affineExprs, isEq);
}
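// For example, with spatialRank = 2 and padBelow = [1, 1], the set is
// (d0, d1)[s0, s1, s2, s3] : (d0 - 1 - s0 >= 0, 1 + s2 - d0 - 1 >= 0,
//                             d1 - 1 - s1 >= 0, 1 + s3 - d1 - 1 >= 0)
// where s0/s1 are the image spatial lower bounds and s2/s3 the upper
// bounds.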
// Initialize output to zero
{
IndexHandle n, k, c;
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs =
makeHandlePointers(MutableArrayRef<IndexHandle>(resSpatialIndices));
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
SmallVector<IndexHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
if (groupConvolution && groupsInFilters)
{
// compute the global C_OUT index from gId and k:
// gId * C_OUT (num filters per group) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
}
else
{
resIndices.push_back(k);
}
resIndices.insert(
resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
ValueHandle zero = createZeroConstant(elemTy);
iRes(resIndices) = zero;
});
});
});
}
IndexHandle n, k, c;
// Convolution loop
LoopBuilder::makeAffine(&n, batchLb, batchUb, 1)([&] {
// Number of filters loop
LoopBuilder::makeAffine(&k, numFiltersLb, numFiltersUb, 1)([&] {
// Channels loop
LoopBuilder::makeAffine(&c, numChannelsLb, numChannelsUb, 1)([&] {
// Results loop
AffineLoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
// Compute image start indices
SmallVector<IndexHandle, 4> imgStartIndices;
for (auto i = 0; i < spatialRank; i++)
{
IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
auto stride = intrinsics::constant_index(iAttr.getInt());
imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride));
}
SmallVector<IndexHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
if (groupConvolution && groupsInFilters)
{
// gId * C_OUT (num of filters) + k
resIndices.push_back(IndexHandle(ValueHandle(gId) * numFiltersUb + k));
}
else
{
resIndices.push_back(k);
}
resIndices.insert(
resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
// Filters spatial loop
AffineLoopNestBuilder(filtersSpatialIndicesPtrs,
filtersSpatialLbs,
filtersSpatialUbs,
filtersSteps)([&] {
SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
// Image indices
// Here we compute the virtual start index into the padded image.
imgIndices.push_back(n);
imgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++)
{
imgIndices.push_back(
IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
}
// Filter indices
// If we are doing group convolution and filters shape dim0
// holds the number of groups, we need to use group id as the first
// index
if (groupConvolution && groupsInFilters)
{
filtersIndices.push_back(IndexHandle(gId));
}
filtersIndices.push_back(k);
// subtract lower bound of channel
// if we are doing group convolution this bound will advance based
// on the group id. For the filters, it should always start from 0
filtersIndices.push_back(IndexHandle(c - numChannelsLb));
filtersIndices.insert(filtersIndices.end(),
filtersSpatialIndices.begin(),
filtersSpatialIndices.end());
if (withPadding)
{
// if args : img dims, img lbs, img ubs
SmallVector<IndexHandle, 4>::iterator it = imgIndices.begin();
std::advance(it, 2);
SmallVector<Value*, 4> affineIfArgs(it, imgIndices.end());
affineIfArgs.insert(
affineIfArgs.end(), imgSpatialLbs.begin(), imgSpatialLbs.end());
affineIfArgs.insert(
affineIfArgs.end(), imgSpatialUbs.begin(), imgSpatialUbs.end());
auto affineIfOp =
rewriter.create<AffineIfOp>(rewriter.getUnknownLoc(),
nonPaddedRange,
affineIfArgs,
/*withElseRegion=*/false);
{
auto rewriter = affineIfOp.getThenBodyBuilder();
ScopedContext scope(rewriter, loc);
// We must subtract pad below before img load, since the
// physical image is not padded
SmallVector<IndexHandle, 4> adjustedImgIndices;
adjustedImgIndices.push_back(n);
adjustedImgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++)
{
adjustedImgIndices.push_back(IndexHandle(
imgIndices[2 + i] -
intrinsics::constant_index(padBelowIntValues[i])));
}
iRes(resIndices) =
iRes(resIndices) +
(iImages(adjustedImgIndices) * iFilters(filtersIndices));
}
}
else
{
iRes(resIndices) = iRes(resIndices) +
(iImages(imgIndices) * iFilters(filtersIndices));
}
});
});
});
});
});
}
template <typename OP>
void lowerUnaryElementwise(Operation* op,
ArrayRef<Value*> operands,
@@ -36,6 +36,7 @@ MLIR_OP(NGDotOp , false )
MLIR_OP(NGGatherOp , false )
MLIR_OP(NGGemmOp , false )
MLIR_OP(NGGreaterOp , true )
MLIR_OP(NGGroupConvOp , false )
MLIR_OP(NGLessOp , true )
MLIR_OP(NGGreaterEqOp , true )
MLIR_OP(NGLessEqOp , true )
@@ -740,6 +740,7 @@ def NGGroupConvOp :
void setPadAbove(const ArrayAttr& attr) { this->setAttr("padAbove", attr); }
void setPadBelow(const ArrayAttr& attr) { this->setAttr("padBelow", attr); }
void setPadType(const Attribute& attr) { this->setAttr("padType", attr); }
void setGroups(const Attribute& attr) { this->setAttr("groups", attr); }
}];
}
@@ -15,6 +15,7 @@ MLIR_OP(Convolution)
MLIR_OP(Gather)
MLIR_OP(Gemm)
MLIR_OP(Greater)
MLIR_OP(GroupConvolution)
MLIR_OP(Less)
MLIR_OP(GreaterEq)
MLIR_OP(LessEq)
@@ -112,7 +112,8 @@ namespace
/// Converts an ngraph shape to an I64 array attribute
template <typename T>
mlir::ArrayAttr getShapeAsAttr(T ngShape);
/// Returns the builder
mlir::OpBuilder& getBuilder() { return m_builder; }
/// Return the real input node corresponding to the fake node
ngraph::Node* getOriginArg(ngraph::Node* node) const;
@@ -452,6 +453,27 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Convolution)
attr = NgDialectObj.getShapeAsAttr(convNode->get_padding_above());
convOp.setPadAbove(attr);
return op;
}
template <>
mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::GroupConvolution)
{
mlir::Operation* op = NgDialectObj.createGenericOp<mlir::NGGroupConvOp>(ngNode);
auto gConvNode = static_cast<const ngraph::op::GroupConvolution*>(ngNode);
auto gConvOp = llvm::cast<mlir::NGGroupConvOp>(op);
mlir::ArrayAttr attr = NgDialectObj.getShapeAsAttr(gConvNode->get_window_movement_strides());
gConvOp.setStrides(attr);
attr = NgDialectObj.getShapeAsAttr(gConvNode->get_padding_below());
gConvOp.setPadBelow(attr);
attr = NgDialectObj.getShapeAsAttr(gConvNode->get_padding_above());
gConvOp.setPadAbove(attr);
gConvOp.setGroups(NgDialectObj.getBuilder().getI64IntegerAttr(gConvNode->get_groups()));
return op;
}
@@ -589,7 +611,6 @@ mlir::Operation* NgDialectConversionPass::COMPILE_OP_DECL(ngraph::op::Softmax)
softmaxOp.setAxes(attr);
return op;
}
template <typename Op>
mlir::Operation* NgDialectConversionPass::createGenericOp(const ngraph::Node* ngNode, int inNum)
{
@@ -150,7 +150,6 @@ func @simple_dot(%arg0: !ng.tensor<16x8xf32>, %arg1: !ng.tensor<8x32xf32>) -> !n
// -----
// std.view
// CHECK-DAG: #[[MAP0:[a-zA-Z0-9]+]] = (d0, d1) -> (d0 * 2 + d1)
// CHECK: %[[T1:[0-9]+]] = alloc() : memref<24xi8>
// CHECK-NEXT: %[[T2:[0-9]+]] = std.view %[[T1]][][] : memref<24xi8> to memref<3x2xf32, #[[MAP0]]>
@@ -165,3 +164,82 @@ func @add(%arg0: !ng.tensor<3x2xf32>, %arg1: !ng.tensor<3x2xf32>) -> !ng.tensor<
%3 = "ng.add"(%2, %2) : (!ng.tensor<3x2xf32>, !ng.tensor<3x2xf32>) -> !ng.tensor<3x2xf32>
"ng.return"(%3) : (!ng.tensor<3x2xf32>) -> ()
}
// -----
// Convolution
// CHECK-LABEL: func @convolution
// Initialization loops
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK: affine.store
// Convolution loops
// CHECK: affine.for %[[a3:.*]] = 0 to 1
// CHECK: affine.for %[[a4:.*]] = 0 to 2
// CHECK: affine.for %[[a5:.*]] = 0 to 2
// CHECK: affine.for %[[a6:.*]] = 0 to 2
// CHECK: affine.for %[[a7:.*]] = 0 to 2
// CHECK: affine.for %[[a8:.*]] = 0 to 1
// CHECK: affine.for %[[a9:.*]] = 0 to 1
// CHECK: affine.load %{{.*}}[%[[a4]], %{{.*}}, %[[a8]], %[[a9]]] : memref<2x2x1x1xf32>
// CHECK: affine.load %{{.*}}[%[[a3]], %[[a5]], %{{.*}}, {{.*}}] : memref<1x2x2x2xf32>
// CHECK-NEXT: mulf
// CHECK-NEXT: affine.load %{{.*}}[%[[a3]], %[[a4]], %[[a6]], %[[a7]]] : memref<1x2x2x2xf32>
// CHECK-NEXT: %[[v4:.*]] = addf
// CHECK-NEXT: affine.store %[[v4]], %{{.*}}[%[[a3]], %[[a4]], %[[a6]], %[[a7]]] : memref<1x2x2x2xf32>
func @convolution(%arg0: !ng.tensor<1x2x2x2xf32>, %arg1: !ng.tensor<2x2x1x1xf32>) -> !ng.tensor<1x2x2x2xf32> {
%0 = "ng.convolution"(%arg0, %arg1) {padAbove = [0, 0], padBelow = [0, 0], strides = [1, 1]} : (!ng.tensor<1x2x2x2xf32>, !ng.tensor<2x2x1x1xf32>) -> !ng.tensor<1x2x2x2xf32>
"ng.return"(%0) : (!ng.tensor<1x2x2x2xf32>) -> ()
}
// -----
//
// Group Convolution
// CHECK-DAG: #[[M0:.*]] = (d0) -> (d0 * 2)
// CHECK-DAG: #[[M1:.*]] = (d0) -> (d0 * 2 + 2)
// CHECK-DAG: #[[M2:.*]] = (d0) -> (d0)
// CHECK-DAG: #[[M3:.*]] = (d0) -> (d0 + 1)
// CHECK-DAG: #[[M8:.*]] = (d0, d1) -> (d0 + d1)
// CHECK-DAG: #[[M9:.*]] = (d0, d1) -> (d0 - d1 * 2)
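// (#[[M0]]/#[[M1]] are the per-group input-channel bounds cLb/cUb
// (gid*2 to gid*2+2), #[[M2]]/#[[M3]] the per-group filter bounds kLb/kUb
// (gid to gid+1), #[[M8]] the image spatial index i + j, and #[[M9]] the
// filter-channel rebase c - gid*2.)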
// CHECK-LABEL: func @groupConv
//
// Outer groups loops
// CHECK: affine.for %[[gid:.*]] = 0 to 2
// CHECK: %[[v0:.*]] = affine.apply #[[M0]](%[[gid]])
// CHECK: %[[v1:.*]] = affine.apply #[[M1]](%[[gid]])
// CHECK: %[[v2:.*]] = affine.apply #[[M2]](%[[gid]])
// CHECK: %[[v3:.*]] = affine.apply #[[M3]](%[[gid]])
//
// Initialization loops
// CHECK: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK-NEXT: affine.for
// CHECK: %[[cst:.*]] = constant 0
// CHECK: affine.store %[[cst]]
//
// Convolution loops
// CHECK: affine.for %[[a4:.*]] = 0 to 1
// CHECK: affine.for %[[a5:.*]] = #[[M2]](%[[v2]]) to #[[M2]](%[[v3]])
// CHECK: affine.for %[[a6:.*]] = #[[M2]](%[[v0]]) to #[[M2]](%[[v1]])
// CHECK: affine.for %[[a7:.*]] = 0 to 2
// CHECK: affine.for %[[a8:.*]] = 0 to 2
// CHECK: affine.for %[[a9:.*]] = 0 to 1
// CHECK: affine.for %[[a10:.*]] = 0 to 1
// CHECK: %[[v6:.*]] = affine.apply #[[M8]](%[[a7]], %[[a9]])
// CHECK: %[[v7:.*]] = affine.apply #[[M8]](%[[a8]], %[[a10]])
// CHECK: %[[v8:.*]] = affine.apply #[[M9]](%[[a6]], %[[gid]])
// CHECK: affine.load %{{.*}}[%[[a5]], %[[v8]], %[[a9]], %[[a10]]] : memref<2x2x1x1xf32>
// CHECK: affine.load %{{.*}}[%[[a4]], %[[a6]], %[[v6]], %[[v7]]] : memref<1x4x2x2xf32>
// CHECK-NEXT: mulf
// CHECK-NEXT: affine.load %{{.*}}[%[[a4]], %[[a5]], %[[a7]], %[[a8]]] : memref<1x2x2x2xf32>
// CHECK-NEXT: %[[v4:.*]] = addf
// CHECK-NEXT: affine.store %[[v4]], %{{.*}}[%[[a4]], %[[a5]], %[[a7]], %[[a8]]] : memref<1x2x2x2xf32>
func @groupConv(%arg0: !ng.tensor<1x4x2x2xf32>, %arg1: !ng.tensor<2x2x1x1xf32>) -> !ng.tensor<1x2x2x2xf32> {
%0 = "ng.groupConv"(%arg0, %arg1) {groups = 2 : i64, padAbove = [0, 0], padBelow = [0, 0], strides = [1, 1]} : (!ng.tensor<1x4x2x2xf32>, !ng.tensor<2x2x1x1xf32>) -> !ng.tensor<1x2x2x2xf32>
"ng.return"(%0) : (!ng.tensor<1x2x2x2xf32>) -> ()
}