Commit 33ec9a8b authored by nmostafa's avatar nmostafa

Lowering to affine. CPU.gather_* pass

parent 67536ddf
...@@ -171,7 +171,7 @@ static mlir::LogicalResult verifyCmpOp(T* op) ...@@ -171,7 +171,7 @@ static mlir::LogicalResult verifyCmpOp(T* op)
template <> template <>
mlir::LogicalResult verifyOp(NGGatherOp* op) mlir::LogicalResult verifyOp(NGGatherOp* op)
{ {
Type ty = op->input()->getType(); Type ty = op->params()->getType();
NGTensorType inputType = ty.cast<NGTensorType>(); NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType(); ty = op->indices()->getType();
......
...@@ -260,14 +260,14 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red"> ...@@ -260,14 +260,14 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
// Gather // Gather
def NGGatherOp : def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>, NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$input, NG_TensorType:$indices, I64Attr:$axis)> Arguments<(ins NG_TensorType:$params, NG_TensorType:$indices, I64Attr:$axis)>
{ {
let summary = "Gather slices from input along the specified axis according to indices"; let summary = "Gather slices from params along the specified axis according to indices";
let description = [{ let description = [{
Gather slices from axis of input according to indices Gather slices from axis of params according to indices
input The tensor from which slices are gathered params The tensor from which slices are gathered
indices Index tensor: Data type must be `element::i32` or `element::i64` indices Index tensor. Data type must be `element::i32` or `element::i64`
axis Axis in input to gather axis Axis in params to gather
}]; }];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }]; let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
......
...@@ -647,14 +647,127 @@ namespace ...@@ -647,14 +647,127 @@ namespace
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGReturnOp) REWRITER(NGGatherOp)
{ {
rewriter.replaceOpWithNewOp<ReturnOp>(op); auto gatherOp = cast<NGGatherOp>(op);
auto loc = gatherOp.getLoc();
ScopedContext scope(rewriter, loc);
// Get operands
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
auto resultTy = result->getType().cast<MemRefType>();
Value* params = operands[0];
Value* indices = operands[1];
auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
SmallVector<IndexHandle, 4> paramsIVs;
SmallVector<int64_t, 4> paramsSteps;
SmallVector<ValueHandle*, 4> paramsIVPtrs;
for (auto i = 0; i < vParams.rank(); i++)
{
// skip gather axis
if (i == axis)
continue;
paramsLbs.push_back(IndexHandle(vParams.lb(i)));
paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsSteps.push_back(vParams.step(i));
}
NGRAPH_CHECK(
paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank()-1);
paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps();
auto indicesIVs = IndexHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = IndexHandle::makeIndexHandlePointers(indicesIVs);
SmallVector<IndexHandle, 8> paramsIndices, resIndices;
// Make sure we are going to create loops
NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
// Let params rank : N
// Let indices rank : M
// Let axis be A
// Generate
// params loops
// for P_0: 0 -> params.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ...
// for P_(A-1):0 -> params.dim[A-1]
// for P_(A+1):0 -> params.dim[A+1]
// ...
// for P_(N-1):0 -> params.dim[N-1]
// indices loops
// for I_0:0 -> indices.dim[0]
// ...
// for I_(M-1):0 -> indices.dim[M-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
// Load axis value from indices array and cast it to Index Type
ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
// construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank(); i++)
{
if (i == axis)
{
paramsIndices.push_back(IndexHandle(axisIdx));
}
else
{
paramsIndices.push_back(paramsIVs[j++]);
}
}
// construct indices for result
// [P_0, P_1, .. P_axis-1, I0, I1, .. I_k-1, P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank() + vIndices.rank() - 1;)
{
if (i == axis && indicesIVs.size() > 0)
{
resIndices.append(indicesIVs.begin(), indicesIVs.end());
i += indicesIVs.size();
}
else
{
resIndices.push_back(paramsIVs[j++]);
i++;
}
}
// Store into result
iRes(resIndices) = iParams(paramsIndices);
});
});
rewriter.replaceOp(op, {result});
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGGatherOp) REWRITER(NGReturnOp)
{ {
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess(); return matchSuccess();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment