Commit 33ec9a8b authored by nmostafa's avatar nmostafa

Lowering to affine. CPU.gather_* pass

parent 67536ddf
......@@ -171,7 +171,7 @@ static mlir::LogicalResult verifyCmpOp(T* op)
template <>
mlir::LogicalResult verifyOp(NGGatherOp* op)
{
Type ty = op->input()->getType();
Type ty = op->params()->getType();
NGTensorType inputType = ty.cast<NGTensorType>();
ty = op->indices()->getType();
......
......@@ -260,14 +260,14 @@ def NGAnyRedOp : NG_Axis_Reduction_Op<"any.red">
// Gather
def NGGatherOp :
NG_OneResult_Op<"gather", [NoSideEffect]>,
Arguments<(ins NG_TensorType:$input, NG_TensorType:$indices, I64Attr:$axis)>
Arguments<(ins NG_TensorType:$params, NG_TensorType:$indices, I64Attr:$axis)>
{
let summary = "Gather slices from input along the specified axis according to indices";
let summary = "Gather slices from params along the specified axis according to indices";
let description = [{
Gather slices from axis of input according to indices
input The tensor from which slices are gathered
indices Index tensor: Data type must be `element::i32` or `element::i64`
axis Axis in input to gather
Gather slices from axis of params according to indices
params The tensor from which slices are gathered
indices Index tensor. Data type must be `element::i32` or `element::i64`
axis Axis in params to gather
}];
let parser = [{ NGRAPH_CHECK(false, "No parser support"); return mlir::failure(); }];
......
......@@ -647,14 +647,127 @@ namespace
return matchSuccess();
}
REWRITER(NGReturnOp)
REWRITER(NGGatherOp)
{
rewriter.replaceOpWithNewOp<ReturnOp>(op);
auto gatherOp = cast<NGGatherOp>(op);
auto loc = gatherOp.getLoc();
ScopedContext scope(rewriter, loc);
// Get operands
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
NGRAPH_CHECK(result, "Unexpected null result in GatherOp");
auto resultTy = result->getType().cast<MemRefType>();
Value* params = operands[0];
Value* indices = operands[1];
auto axis = gatherOp.axis().getSExtValue();
// Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
SmallVector<IndexHandle, 4> paramsIVs;
SmallVector<int64_t, 4> paramsSteps;
SmallVector<ValueHandle*, 4> paramsIVPtrs;
for (auto i = 0; i < vParams.rank(); i++)
{
// skip gather axis
if (i == axis)
continue;
paramsLbs.push_back(IndexHandle(vParams.lb(i)));
paramsUbs.push_back(IndexHandle(vParams.ub(i)));
paramsSteps.push_back(vParams.step(i));
}
NGRAPH_CHECK(
paramsLbs.size() == vParams.rank() - 1 &&
paramsUbs.size() == paramsLbs.size() &&
paramsSteps.size() == paramsLbs.size(),
"Incorrect loop nest bounds size for gather params");
paramsIVs = IndexHandle::makeIndexHandles(vParams.rank()-1);
paramsIVPtrs = IndexHandle::makeIndexHandlePointers(paramsIVs);
auto indicesLbs = vIndices.getLbs();
auto indicesUbs = vIndices.getUbs();
auto indicesSteps = vIndices.getSteps();
auto indicesIVs = IndexHandle::makeIndexHandles(vIndices.rank());
auto indicesIVPtrs = IndexHandle::makeIndexHandlePointers(indicesIVs);
SmallVector<IndexHandle, 8> paramsIndices, resIndices;
// Make sure we are going to create loops
NGRAPH_CHECK(vParams.rank() > 0, "Invalid size for indices steps");
// Let params rank : N
// Let indices rank : M
// Let axis be A
// Generate
// params loops
// for P_0: 0 -> params.dim[0]
// for P_1: 0 -> params.dim[1]
// for P_2: 0 -> params.dim[2]
// ...
// for P_(A-1):0 -> params.dim[A-1]
// for P_(A+1):0 -> params.dim[A+1]
// ...
// for P_(N-1):0 -> params.dim[N-1]
// indices loops
// for I_0:0 -> indices.dim[0]
// ...
// for I_(M-1):0 -> indices.dim[M-1]
// res[P_0, P_1, .. P_(A-1), I_0, .., I_(M-1), P_(A+1), ... P_(N-1)] =
// params[P_0, P_1, .. P_(A-1), indices[I_0, .., I_(M-1)], P_(A+1), ... P_(N-1)];
LoopNestBuilder(paramsIVPtrs, paramsLbs, paramsUbs, paramsSteps)([&] {
LoopNestBuilder(indicesIVPtrs, indicesLbs, indicesUbs, indicesSteps)([&] {
// Load axis value from indices array and cast it to Index Type
ValueHandle axisIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iIndices(indicesIVs), rewriter.getIndexType());
// construct indices for param
// [P_0, P_1, .. P_axis-1, Indices[I0, I1, .. I_k-1], P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank(); i++)
{
if (i == axis)
{
paramsIndices.push_back(IndexHandle(axisIdx));
}
else
{
paramsIndices.push_back(paramsIVs[j++]);
}
}
// construct indices for result
// [P_0, P_1, .. P_axis-1, I0, I1, .. I_k-1, P_axis+1, P_axis+2, .. P_n-1]
for (auto i = 0, j = 0; i < vParams.rank() + vIndices.rank() - 1;)
{
if (i == axis && indicesIVs.size() > 0)
{
resIndices.append(indicesIVs.begin(), indicesIVs.end());
i += indicesIVs.size();
}
else
{
resIndices.push_back(paramsIVs[j++]);
i++;
}
}
// Store into result
iRes(resIndices) = iParams(paramsIndices);
});
});
rewriter.replaceOp(op, {result});
return matchSuccess();
}
REWRITER(NGGatherOp)
REWRITER(NGReturnOp)
{
rewriter.replaceOpWithNewOp<ReturnOp>(op);
return matchSuccess();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment