Commit d0b97e73 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Fix issue with non-affine loads/stores in argmin/armax/gather ops (#3351)

parent 9f928d92
......@@ -43,6 +43,8 @@ namespace
using namespace mlir::edsc::op;
using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
class DialectLoweringPass;
......@@ -682,7 +684,8 @@ namespace
// Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices);
IndexedValue iRes(result), iIndices(indices);
StdIndexedValue iParams(params);
// Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
......@@ -894,7 +897,8 @@ namespace
// Views
MemRefView vRes(result), vArg(arg);
// Index Values
IndexedValue iRes(result), iArg(arg);
StdIndexedValue iRes(result), stdArg(arg);
IndexedValue affineArg(arg);
// Bounds Index Handles
auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs();
......@@ -944,9 +948,9 @@ namespace
ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select(
iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
});
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment