Commit ced18814 authored by Diego Caballero's avatar Diego Caballero

[MLIR] Replace ad-hoc index-to-int casting with index cast op

parent a875fc8a
......@@ -21,7 +21,6 @@ set(SRC
compiler.cpp
lowerer.cpp
memory_manager.cpp
helpers.cpp
pass/mlir_subgraph_extraction.cpp
pass/mlir_subgraph_extraction.hpp
)
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <stdint.h>
#include "ngraph/ngraph_visibility.hpp"
/// Call back to copy Index tensor to Int tensor
/// Can handle int tensors of bitwidth 8, 16, 32 and 64
/// Index width is always intptr_t
extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst,
mlir::StaticFloatMemRef src,
size_t numElements,
size_t intWidth)
{
size_t indexSize = sizeof(intptr_t);
auto pSrc = reinterpret_cast<intptr_t*>(src.data);
auto pDst = reinterpret_cast<char*>(dst.data);
for (auto i = 0; i < numElements; i++)
{
switch (intWidth)
{
case 8:
*pDst = static_cast<char>(pSrc[i]);
pDst++;
break;
case 16:
*(short*)pDst = static_cast<short>(pSrc[i]);
pDst += sizeof(short);
break;
case 32:
*(int*)pDst = static_cast<int>(pSrc[i]);
pDst += sizeof(int);
break;
case 64:
*(long*)pDst = static_cast<long>(pSrc[i]);
pDst += sizeof(long);
break;
}
}
}
......@@ -66,8 +66,7 @@ namespace
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass,
bool isMin);
DialectLoweringPass& m_pass);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
......@@ -386,13 +385,13 @@ namespace
REWRITER(NGArgMaxRedOp)
{
lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass, false);
lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGArgMinRedOp)
{
lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass, true);
lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
......@@ -468,16 +467,18 @@ namespace
#undef REWRITER
template <typename T>
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass,
bool isMin)
DialectLoweringPass& m_pass)
{
T argmin = cast<T>(op);
auto loc = argmin.getLoc();
auto axesAttr = argmin.axes();
static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
"Template parameter is not supported by lowerIndexReduction");
RedOp redOp = cast<RedOp>(op);
auto loc = redOp.getLoc();
auto axesAttr = redOp.axes();
NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
Attribute axisAttr = *axesAttr.begin();
......@@ -489,19 +490,8 @@ namespace
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
Value* finalResult = m_pass.buildOutputDefs(op, rewriter)[0];
Type type = argmin.getResult()->getType();
NGTensorType resultTy = type.cast<NGTensorType>();
// MLIR doesn't support Index to/from Integer type-conversion
// We have to store our result in an IndexType tensor and call-back to a type-conversion routine in nGraph
// TODO: Fix this once MLIR provides explicit cast operations.
Value* result = m_pass.createTempTensor(
rewriter.getMemRefType(resultTy.getShape(), rewriter.getIndexType()),
resultTy.getNumElements() *
sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
rewriter);
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
// Views
MemRefView vRes(result), vArg(arg);
......@@ -512,73 +502,60 @@ namespace
auto resUbs = vRes.getUbs();
auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs();
Type resTy = result->getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis);
// clang-format off
LoopNestBuilder(pivs, resLbs, resUbs, steps)(
// single stmt body
[&] {
iRes(ivs) = initVal;
}
);
}
LoopNestBuilder(pivs, resLbs, resUbs, steps)(
[&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
}
// reduction loops
{
auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
SmallVector<IndexHandle,8> nonRedIVs;
// Generate loop nest that computes the actual index reduction.
{
auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs;
auto steps = vArg.getSteps();
Type resTy = result->getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(),
"Expected integer result type in index reduction");
// iterate over all argument dimensions
LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)(
[&] {
// iterate over all argument dimensions
LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
// build a list of non-reduction IVs
for (auto i = 0; i < vArg.rank(); i++)
{
if (i != axis)
nonRedIVs.push_back(allIVs[i]);
}
// load current min index
ValueHandle currMinIndx = iRes(nonRedIVs);
// Load current min index with integer data type and convert it to index data type.
ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));
// Build list of IVs including current min index.
auto tempIVs = allIVs;
// build list of IVs including current min index
tempIVs[axis] = currMinIndx;
iRes(nonRedIVs) = isMin ? edsc::intrinsics::select(iArg(allIVs) < iArg(tempIVs), allIVs[axis], currMinIndx) :
edsc::intrinsics::select(iArg(tempIVs) < iArg(allIVs), allIVs[axis], currMinIndx);
}
);
}
tempIVs[axis] = currRedIdx;
// Call-back to convert Index tensor to Integer tensor
auto callBackFunc = m_pass.getCallDecl("__mlir_convert_index_to_int",
{finalResult->getType(), result->getType(), rewriter.getIndexType(), rewriter.getIndexType()},
{},
rewriter);
SmallVector<mlir::Value*, 4> args = {finalResult, /* dst tensor */
result, /* src tensor */
/* Num of Elements */
rewriter.create<mlir::ConstantIndexOp>(
rewriter.getUnknownLoc(),
resultTy.getNumElements()
),
/* Integer size used in final result*/
rewriter.create<mlir::ConstantIndexOp>(
rewriter.getUnknownLoc(),
resultTy.getElementType().cast<NGIntegerType>().getWidth()
)
};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, {finalResult});
}
// Select the min/max value and cast it back to integer type before storing it.
ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select(
iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
});
}
rewriter.replaceOp(op, result);
}
}
namespace mlir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment