Commit 11d61848 authored by Nagy Mostafa's avatar Nagy Mostafa Committed by Scott Cyphers

[MLIR] Add support for Convolution Op without Padding (#3540)

* Op decl. No verification

* WIP

* WIP: Add lowerer support

* Code-gen works.

* Added padding support. Needs MLIR fix to work

* Remove Padding support for now

* Fix arrayref init. Disable tests where reshape optimization doesn't apply

* Clean up and apply style

* Address PR feedback
parent 6a4850e5
......@@ -31,6 +31,7 @@
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
......@@ -249,8 +250,8 @@ void MLIRCompiler::build_ng_dialect_module()
dump_mlir_module("nGraph Dialect Construction");
}
// Converts nGraph shape \p ng_shape to MLIR shape \p mlir_shape.
static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
template <typename T>
void MLIRCompiler::get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
{
for (auto dim : ng_shape)
{
......@@ -258,11 +259,19 @@ static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t
}
}
template <typename T>
mlir::ArrayAttr MLIRCompiler::get_shape_as_attr(T ng_shape)
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(ng_shape, mlir_shape);
return m_builder->getI64ArrayAttr(mlir_shape);
}
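The helper is templated so the same code serves Shape, Strides, and CoordinateDiff. A hypothetical call for illustration (inside MLIRCompiler, since get_shape_as_attr is a member):

// Any iterable of integral dims works as T (Shape, Strides, CoordinateDiff, ...):
mlir::ArrayAttr attr = get_shape_as_attr(ngraph::Strides{1, 1}); // I64 array attribute [1, 1]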
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> mlir_shape;
llvm::SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(
&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
......@@ -592,6 +601,25 @@ namespace ngraph
{
return compiler.create_generic_op<mlir::NGNegOp>(ng_node);
}
template <>
mlir::Operation* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Convolution)
{
mlir::Operation* op = compiler.create_generic_op<mlir::NGConvolutionOp>(ng_node);
auto conv_node = static_cast<const ngraph::op::Convolution*>(ng_node);
auto conv_op = llvm::cast<mlir::NGConvolutionOp>(op);
mlir::ArrayAttr attr =
compiler.get_shape_as_attr(conv_node->get_window_movement_strides());
conv_op.setStrides(attr);
attr = compiler.get_shape_as_attr(conv_node->get_padding_below());
conv_op.setPadBelow(attr);
attr = compiler.get_shape_as_attr(conv_node->get_padding_above());
conv_op.setPadAbove(attr);
return op;
}
}
}
}
......
......@@ -135,6 +135,14 @@ namespace ngraph
/// Helper to dump MLIR module into llvm::dbgs prepended by the message \p msg.
void dump_mlir_module(const std::string msg);
/// Converts nGraph shape-like types \p ng_shape to MLIR shape \p mlir_shape.
template <typename T>
void get_mlir_shape(T ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape);
/// Converts an nGraph shape to an I64 array attribute.
template <typename T>
mlir::ArrayAttr get_shape_as_attr(T ng_shape);
private:
// Sub-graph to be compiled and executed with MLIR.
const ngraph::op::CompiledKernel* m_compiled_kernel;
......
......@@ -205,6 +205,105 @@ mlir::LogicalResult verifyOp(NGGatherOp* op)
return mlir::success();
}
template <>
mlir::LogicalResult verifyOp(NGConvolutionOp* op)
{
Type ty = op->images()->getType();
NGTensorType imagesType = ty.cast<NGTensorType>();
Type imagesEt = imagesType.getElementType();
Shape imagesShape = imagesType.getShape();
ty = op->filters()->getType();
NGTensorType filtersType = ty.cast<NGTensorType>();
Type filtersEt = filtersType.getElementType();
Shape filtersShape = filtersType.getShape();
ty = op->res()->getType();
NGTensorType resultType = ty.cast<NGTensorType>();
Type resultEt = resultType.getElementType();
Shape resultShape = resultType.getShape();
ArrayAttr strides = op->strides();
ArrayAttr padBelow = op->padBelow();
ArrayAttr padAbove = op->padAbove();
unsigned imagesRank = imagesShape.size();
unsigned filtersRank = filtersShape.size();
unsigned resultRank = resultShape.size();
unsigned imageSpatialRank = imagesRank - 2;
unsigned filtersSpatialRank = filtersRank - 2;
unsigned stridesRank = strides.size();
unsigned padBelowRank = padBelow.size();
unsigned padAboveRank = padAbove.size();
SmallVector<int64_t, 4> stridesVal, padAboveVal, padBelowVal;
// Identical filters and image element types
if (filtersEt != imagesEt)
{
return op->emitOpError("Incompatible image and filters types");
}
// Verify image shape
if (imagesRank < 3)
{
return op->emitOpError("Image shape of rank below 3");
}
// Verify strides and pads shapes
if (imageSpatialRank != stridesRank || imageSpatialRank != padBelowRank ||
imageSpatialRank != padAboveRank)
{
return op->emitOpError("Image spatial rank mismatches strides and/or padding ranks");
}
if (imageSpatialRank != filtersSpatialRank)
{
return op->emitOpError("Image and filters spatial ranks mismatch");
}
// Batch size is non-zero, and identical non-zero channel depth
if (imagesShape[0] <= 0 || filtersShape[0] <= 0 || imagesShape[1] != filtersShape[1] ||
imagesShape[1] <= 0)
{
return op->emitOpError("Image and filters have invalid shapes");
}
for (auto attrs : llvm::zip(strides, padBelow, padAbove))
{
auto s = std::get<0>(attrs).cast<IntegerAttr>().getInt();
auto pb = std::get<1>(attrs).cast<IntegerAttr>().getInt();
auto pa = std::get<2>(attrs).cast<IntegerAttr>().getInt();
if (s <= 0)
{
return op->emitOpError("Window stride must be non-negative");
}
if (pb < 0 || pa < 0)
{
return op->emitOpError("Paddings must be positive");
}
stridesVal.push_back(s);
padBelowVal.push_back(pb);
padAboveVal.push_back(pa);
}
// Check output shape
if (resultRank != imagesRank || resultShape[0] != imagesShape[0] ||
resultShape[1] != filtersShape[0])
{
return op->emitOpError("Invalid result shape");
}
for (unsigned i = 0; i < resultRank - 2; i++)
{
unsigned resDim = llvm::divideCeil(padBelowVal[i] + padAboveVal[i] + imagesShape[2 + i] -
filtersShape[2 + i] + 1,
stridesVal[i]);
if (resultShape[i + 2] != resDim)
{
return op->emitOpError("Invalid result spatial shape");
}
}
return mlir::success();
}
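As a concrete instance of the result-dim formula the loop above enforces (hypothetical sizes, not taken from this commit):

#include <llvm/Support/MathExtras.h>
#include <cassert>

int main()
{
    // R_i = ceil((padBelow_i + D_i + padAbove_i - F_i + 1) / stride_i)
    uint64_t padBelow = 0, padAbove = 0, imageDim = 5, filterDim = 3, stride = 2;
    uint64_t resDim = llvm::divideCeil(padBelow + padAbove + imageDim - filterDim + 1, stride);
    assert(resDim == 2); // ceil(3 / 2) == 2
    return 0;
}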
namespace mlir
{
#define GET_OP_CLASSES
......
......@@ -278,6 +278,35 @@ def NGGatherOp :
let verifier = [{ return verifyOp(this); }];
}
// Convolution
def NGConvolutionOp :
NG_OneResult_Op<"convolution", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$images, NG_TensorType:$filters,
I64ArrayAttr:$strides,
I64ArrayAttr:$padBelow,
I64ArrayAttr:$padAbove)>
{
let summary = "Convolution of a tensor of filters over a tensor of images with padding support";
let description = [{
Convolution operation with padding and stride support. No dilation supported.
images Input image tensor. Shape is [N, C_IN, D1, ... Df]
filters Set of filters to apply. Shape is [C_OUT, C_IN, F1, ... Ff]
strides Window movement strides. Shape is [f]. Attribute.
padBelow The padding-below sizes. Shape is [f]. Attribute.
padAbove The padding-above sizes. Shape is [f]. Attribute.
Output is of shape [N, C_OUT, R1, ... Rf]
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
let verifier = [{ return verifyOp(this); }];
let extraClassDeclaration = [{
void setStrides(ArrayAttr& arrayAttr) { this->setAttr("strides", arrayAttr); }
void setPadBelow(ArrayAttr& arrayAttr) { this->setAttr("padBelow", arrayAttr); }
void setPadAbove(ArrayAttr& arrayAttr) { this->setAttr("padAbove", arrayAttr); }
}];
}
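Since the op intentionally has no parser, it has no committed textual form. Purely as an illustration, its generic MLIR form would look roughly like the following (the !ng.tensor type syntax is an assumption, not verified against the dialect printer):

%res = "ng.convolution"(%images, %filters)
       {strides = [1, 1], padBelow = [0, 0], padAbove = [0, 0]}
       : (!ng.tensor<1x2x2x2xf32>, !ng.tensor<2x2x1x1xf32>) -> !ng.tensor<1x2x2x2xf32>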
// Terminator Ops
def NGReturnOp : NG_Terminator_Op<"return">;
......
......@@ -28,6 +28,8 @@
#include <mlir/EDSC/Builders.h>
#include <mlir/EDSC/Helpers.h>
#include <mlir/EDSC/Intrinsics.h>
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/IntegerSet.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Transforms/DialectConversion.h>
......@@ -751,6 +753,206 @@ namespace
return matchSuccess();
}
REWRITER(NGConvolutionOp)
{
auto convolOp = cast<NGConvolutionOp>(op);
auto loc = convolOp.getLoc();
ScopedContext scope(rewriter, loc);
// Get operands
Value* result = pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in Convolution Op");
Value* images = operands[0];
Value* filters = operands[1];
auto strides = convolOp.strides().getValue();
auto padBelow = convolOp.padBelow().getValue();
auto padAbove = convolOp.padAbove().getValue();
for (auto value : llvm::zip(padBelow, padAbove))
{
auto padAttr = std::get<0>(value);
NGRAPH_CHECK(padAttr.cast<IntegerAttr>().getInt() == 0,
"No support for padding in convolution op");
padAttr = std::get<1>(value);
NGRAPH_CHECK(padAttr.cast<IntegerAttr>().getInt() == 0,
"No support for padding in convolution op");
}
Type elemTy = images->getType().cast<MemRefType>().getElementType();
// Let Images shape be [N, C_IN, D_1, ... D_f]
// Let Filters shape be [C_OUT, C_IN, F_1, ... F_f]
// Output shape will be [N, C_OUT, R_1, ..R_f]
// where R_i = ceil((AdjD_i - AdjF_i + 1) / Strides[i])
//
// AdjD_i is adjusted image spatial dimension after padding and dilation
// AdjD_i = padBelow[i] + (dilation[i] * (D_i - 1) + 1) + padAbove[i]
//
// AdjF_i is adjusted filters spatial dimension after dilation
// AdjF_i = dilation[i] * (F_i - 1) + 1
//
// If no padding, padAbove/Below[i] = 0
// If no dilation, dilation[i] is 1
//
// Generate the following (currently without padding/dilation support)
//
//
// for n : 0 -> N
// for k : 0 -> C_OUT
// for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f>
// //initialize result to zero
// Output[n, k, r_1, .. r_f] = 0;
//
// for n : 0 -> N
// for k : 0 -> C_OUT
// for c : 0 -> C_IN
// // iterate over output spatial shape
// for <r_1 .. r_f> : <0 .. 0> -> <R_1 ... R_f>
// //compute image start inputs indices
// i_1 = r_1 * strides[0];
// ..
// i_f = r_f * strides[f - 1];
// // iterate over kernel spatial shape
// for <j_1 .. j_f> : <0 .. 0> -> <F_1 .. F_f>
// Output[n, k, r_1, .. r_f] +=
// Images[n, c, i_1 + j_1, .. i_f + j_f] * Filters[k, c, j_1, .. j_f]
// TODO: With padding, we need to check (using IntegerSets) whether each spatial dim in
// Images lies within the padding. If yes, we initialize the value to zero, else load from
// the MemRef.
// Q: Can this be done using a map from the padded tensor to the unpadded one? Would we
// load zero if OOB?
// Create view to write into result.
MemRefView vRes(result), vImages(images), vFilters(filters);
// Indexed Values
IndexedValue iRes(result), iImages(images), iFilters(filters);
// Bounds on batch size N
ValueHandle batchLb = vImages.lb(0), batchUb = vImages.ub(0);
// Bounds on number of filters
ValueHandle numFiltersLb = vFilters.lb(0), numFiltersUb = vFilters.ub(0);
// Bound on number of channels
ValueHandle numChannelsLb = vImages.lb(1), numChannelsUb = vImages.ub(1);
// Bounds on result spatial dimensions
SmallVector<ValueHandle, 4> resSpatialLbs, resSpatialUbs;
SmallVector<ValueHandle, 4> imgSpatialLbs, imgSpatialUbs;
SmallVector<ValueHandle, 4> filtersSpatialLbs, filtersSpatialUbs;
// Spatial rank
unsigned spatialRank = vImages.rank() - 2;
// Result spatial indices and bounds
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices);
SmallVector<int64_t, 4> resSteps, filtersSteps;
for (auto i = 0; i < spatialRank; i++)
{
// result spatial bounds and steps
resSpatialLbs.push_back(vRes.lb(i + 2));
resSpatialUbs.push_back(vRes.ub(i + 2));
resSteps.push_back(vRes.step(i + 2));
// image spatial bounds
imgSpatialLbs.push_back(vImages.lb(i + 2));
imgSpatialUbs.push_back(vImages.ub(i + 2));
}
NGRAPH_CHECK(vImages.rank() == vFilters.rank(), "Images and Filters have unequal ranks");
NGRAPH_CHECK(resSpatialLbs.size() == resSpatialUbs.size() &&
resSpatialLbs.size() == spatialRank,
"Results spatial dims mismatches input");
// Filters spatial indices and bounds
auto filtersSpatialIndices = makeIndexHandles(spatialRank);
auto filtersSpatialIndicesPtrs = makeIndexHandlePointers(filtersSpatialIndices);
for (auto i = 0; i < spatialRank; i++)
{
filtersSpatialLbs.push_back(vFilters.lb(i + 2));
filtersSpatialUbs.push_back(vFilters.ub(i + 2));
filtersSteps.push_back(vFilters.step(i + 2));
}
// Initialize output to zero
{
IndexHandle n, k, c;
auto resSpatialIndices = makeIndexHandles(spatialRank);
auto resSpatialIndicesPtrs = makeIndexHandlePointers(resSpatialIndices);
LoopBuilder(&n, batchLb, batchUb, 1)([&] {
LoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
LoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
SmallVector<IndexHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
resIndices.push_back(k);
resIndices.insert(
resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
ValueHandle zero = createZeroConstant(elemTy);
iRes(resIndices) = zero;
});
});
});
}
IndexHandle n, k, c;
// Convolution loop
LoopBuilder(&n, batchLb, batchUb, 1)([&] {
// Number of filters loop
LoopBuilder(&k, numFiltersLb, numFiltersUb, 1)([&] {
// Channels loop
LoopBuilder(&c, numChannelsLb, numChannelsUb, 1)([&] {
// Results loop
LoopNestBuilder(
resSpatialIndicesPtrs, resSpatialLbs, resSpatialUbs, resSteps)([&] {
// Compute image start indices
SmallVector<IndexHandle, 4> imgStartIndices;
for (auto i = 0; i < spatialRank; i++)
{
IntegerAttr iAttr = strides[i].cast<IntegerAttr>();
auto stride = intrinsics::constant_index(iAttr.getInt());
imgStartIndices.push_back(IndexHandle(resSpatialIndices[i] * stride));
}
SmallVector<IndexHandle, 4> resIndices;
// Result indices
resIndices.push_back(n);
resIndices.push_back(k);
resIndices.insert(
resIndices.end(), resSpatialIndices.begin(), resSpatialIndices.end());
// Filters spatial loop
LoopNestBuilder(filtersSpatialIndicesPtrs,
filtersSpatialLbs,
filtersSpatialUbs,
filtersSteps)([&] {
SmallVector<IndexHandle, 4> imgIndices, filtersIndices;
// Image indices
imgIndices.push_back(n);
imgIndices.push_back(c);
for (auto i = 0; i < spatialRank; i++)
{
imgIndices.push_back(
IndexHandle(imgStartIndices[i] + filtersSpatialIndices[i]));
}
// Filter indices
filtersIndices.push_back(k);
filtersIndices.push_back(c);
filtersIndices.insert(filtersIndices.end(),
filtersSpatialIndices.begin(),
filtersSpatialIndices.end());
iRes(resIndices) =
iRes(resIndices) + (iImages(imgIndices) * iFilters(filtersIndices));
});
});
});
});
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
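For readers unfamiliar with the EDSC builders, the loop nest above corresponds to the following scalar C++ reference, specialized to 2-D spatial dims, unit dilation, and zero padding (an illustrative sketch, not code from this commit):

#include <cstddef>
#include <vector>

void conv2d_reference(const std::vector<float>& images,  // [N, C_IN, H, W], row-major
                      const std::vector<float>& filters, // [C_OUT, C_IN, FH, FW]
                      std::vector<float>& output,        // [N, C_OUT, RH, RW]
                      size_t N, size_t C_IN, size_t H, size_t W,
                      size_t C_OUT, size_t FH, size_t FW,
                      size_t strideH, size_t strideW)
{
    size_t RH = (H - FH) / strideH + 1; // equals ceil((H - FH + 1) / strideH)
    size_t RW = (W - FW) / strideW + 1;
    // First pass: initialize output to zero (mirrors the zero-init loop nest).
    output.assign(N * C_OUT * RH * RW, 0.0f);
    // Second pass: the convolution loop nest.
    for (size_t n = 0; n < N; n++)
        for (size_t k = 0; k < C_OUT; k++)
            for (size_t c = 0; c < C_IN; c++)
                for (size_t r1 = 0; r1 < RH; r1++)
                    for (size_t r2 = 0; r2 < RW; r2++)
                    {
                        // Image start indices: i_d = r_d * stride_d
                        size_t i1 = r1 * strideH, i2 = r2 * strideW;
                        for (size_t j1 = 0; j1 < FH; j1++)
                            for (size_t j2 = 0; j2 < FW; j2++)
                                output[((n * C_OUT + k) * RH + r1) * RW + r2] +=
                                    images[((n * C_IN + c) * H + (i1 + j1)) * W + (i2 + j2)] *
                                    filters[((k * C_IN + c) * FH + j1) * FW + j2];
                    }
}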
REWRITER(NGReturnOp)
{
pass.insertDeallocs(rewriter);
......
......@@ -27,6 +27,7 @@ MLIR_OP(NGAddOp)
MLIR_OP(NGArgMaxRedOp)
MLIR_OP(NGArgMinRedOp)
MLIR_OP(NGConcatOp)
MLIR_OP(NGConvolutionOp)
MLIR_OP(NGDivOp)
MLIR_OP(NGDotOp)
MLIR_OP(NGGatherOp)
......
......@@ -9,6 +9,7 @@ MLIR_OP(ArgMax)
MLIR_OP(Divide)
MLIR_OP(Dot)
MLIR_OP(Concat)
MLIR_OP(Convolution)
MLIR_OP(Gather)
MLIR_OP(Greater)
MLIR_OP(Less)
......
......@@ -25,6 +25,7 @@
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
......@@ -480,6 +481,24 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
return true;
}
if (TI(ngraph::op::Convolution) == TI(*node))
{
// No padding or dilation for now
auto conv_node = static_cast<ngraph::op::Convolution*>(node.get());
auto pad_below = conv_node->get_padding_below();
auto pad_above = conv_node->get_padding_above();
auto data_dilation = conv_node->get_data_dilation_strides();
auto window_dilation = conv_node->get_window_dilation_strides();
auto is_zero = [](size_t s) { return s == 0; };
auto is_one = [](size_t s) { return s == 1; };
return std::all_of(pad_below.begin(), pad_below.end(), is_zero) &&
std::all_of(pad_above.begin(), pad_above.end(), is_zero) &&
std::all_of(data_dilation.begin(), data_dilation.end(), is_one) &&
std::all_of(window_dilation.begin(), window_dilation.end(), is_one);
}
return true;
}
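As a hypothetical illustration (data and weights being pre-built Parameter nodes), a convolution constructed with non-zero padding would fail this check and stay on the non-MLIR path:

auto conv = std::make_shared<ngraph::op::Convolution>(data,
                                                      weights,
                                                      ngraph::Strides{1, 1},        // window movement
                                                      ngraph::Strides{1, 1},        // window dilation
                                                      ngraph::CoordinateDiff{1, 1}, // padding below
                                                      ngraph::CoordinateDiff{1, 1}, // padding above
                                                      ngraph::Strides{1, 1});       // data dilation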
......
......@@ -66,3 +66,36 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_outlining)
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(vector<float>{expected_result}, read_vector<float>(result)));
}
NGRAPH_TEST(${BACKEND_NAME}, convolution_simple)
{
Shape shape_a{1, 2, 2, 2};
auto A = make_shared<op::Parameter>(element::f32, shape_a);
Shape shape_b{2, 2, 1, 1};
auto B = make_shared<op::Parameter>(element::f32, shape_b);
Shape shape_r{1, 2, 2, 2};
auto conv1 = make_shared<op::Convolution>(A,
B,
Strides{1, 1},
Strides{1, 1},
CoordinateDiff{0, 0},
CoordinateDiff{0, 0},
Strides{1, 1});
auto f = make_shared<Function>(conv1, ParameterVector{A, B});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape_a);
copy_data(a, vector<float>{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f});
auto b = backend->create_tensor(element::f32, shape_b);
copy_data(b, vector<float>{3.0f, 3.0f, 3.0f, 3.0f});
auto result = backend->create_tensor(element::f32, shape_r);
vector<float> expected_result{18.0f, 24.0f, 30.0f, 36.0f, 18.0f, 24.0f, 30.0f, 36.0f};
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a, b});
EXPECT_TRUE(test::all_close_f(vector<float>{expected_result}, read_vector<float>(result)));
}
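The expected values follow directly from the 1x1 filters: every weight is 3.0, so each output element is 3 * (channel 0 + channel 1) at that position: 3*(1+5)=18, 3*(2+6)=24, 3*(3+7)=30, 3*(4+8)=36, repeated identically for both output channels since the two filters are equal.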
......@@ -231,7 +231,7 @@ TEST(cpu_test, mkldnn_layouts)
EXPECT_TRUE(test::all_close_f(vector<float>{expected_result}, rv));
}
TEST(cpu_test, reshape_layout_optimizations1)
TEST(cpu_test, MLIR_DISABLE_TEST(reshape_layout_optimizations1))
{
// Squeeze outermost dimension
auto make_function = []() -> std::shared_ptr<Function> {
......@@ -270,7 +270,7 @@ TEST(cpu_test, reshape_layout_optimizations1)
}
}
TEST(cpu_test, reshape_layout_optimizations2)
TEST(cpu_test, MLIR_DISABLE_TEST(reshape_layout_optimizations2))
{
// ExpandDims - inner most and internal dims
auto make_function = []() -> std::shared_ptr<Function> {
......@@ -309,7 +309,7 @@ TEST(cpu_test, reshape_layout_optimizations2)
}
}
TEST(cpu_test, reshape_layout_optimizations3)
TEST(cpu_test, MLIR_DISABLE_TEST(reshape_layout_optimizations3))
{
// Squeeze padded dimension
auto make_function = []() -> std::shared_ptr<Function> {
......