Commit d0b97e73 authored by Diego Caballero's avatar Diego Caballero Committed by Scott Cyphers

[MLIR] Fix issue with non-affine loads/stores in argmin/armax/gather ops (#3351)

parent 9f928d92
...@@ -43,6 +43,8 @@ namespace ...@@ -43,6 +43,8 @@ namespace
using namespace mlir::edsc::op; using namespace mlir::edsc::op;
using namespace ngraph::runtime; using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir; using namespace ngraph::runtime::ngmlir;
// Index notation to generate standard (i.e., non-affine) loads and stores.
using StdIndexedValue = TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
class DialectLoweringPass; class DialectLoweringPass;
...@@ -682,7 +684,8 @@ namespace ...@@ -682,7 +684,8 @@ namespace
// Create view to write into result. // Create view to write into result.
MemRefView vRes(result), vParams(params), vIndices(indices); MemRefView vRes(result), vParams(params), vIndices(indices);
// Indexed Values // Indexed Values
IndexedValue iRes(result), iParams(params), iIndices(indices); IndexedValue iRes(result), iIndices(indices);
StdIndexedValue iParams(params);
// Construct outer loop for params dims. Exclude the axis dim. // Construct outer loop for params dims. Exclude the axis dim.
SmallVector<ValueHandle, 4> paramsLbs, paramsUbs; SmallVector<ValueHandle, 4> paramsLbs, paramsUbs;
...@@ -894,7 +897,8 @@ namespace ...@@ -894,7 +897,8 @@ namespace
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefView vRes(result), vArg(arg);
// Index Values // Index Values
IndexedValue iRes(result), iArg(arg); StdIndexedValue iRes(result), stdArg(arg);
IndexedValue affineArg(arg);
// Bounds Index Handles // Bounds Index Handles
auto resLbs = vRes.getLbs(); auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs(); auto resUbs = vRes.getUbs();
...@@ -944,9 +948,9 @@ namespace ...@@ -944,9 +948,9 @@ namespace
ValueHandle newRedIdx = ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>() std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select( ? edsc::intrinsics::select(
iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx) affineArg(allIVs) < stdArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select( : edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx); stdArg(tempIVs) < affineArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy); iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment