Commit ced18814 authored by Diego Caballero's avatar Diego Caballero

[MLIR] Replace ad-hoc index-to-int casting with index cast op

parent a875fc8a
...@@ -21,7 +21,6 @@ set(SRC ...@@ -21,7 +21,6 @@ set(SRC
compiler.cpp compiler.cpp
lowerer.cpp lowerer.cpp
memory_manager.cpp memory_manager.cpp
helpers.cpp
pass/mlir_subgraph_extraction.cpp pass/mlir_subgraph_extraction.cpp
pass/mlir_subgraph_extraction.hpp pass/mlir_subgraph_extraction.hpp
) )
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <stdint.h>
#include "ngraph/ngraph_visibility.hpp"
/// Call back to copy Index tensor to Int tensor
/// Can handle int tensors of bitwidth 8, 16, 32 and 64
/// Index width is always intptr_t
extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst,
mlir::StaticFloatMemRef src,
size_t numElements,
size_t intWidth)
{
size_t indexSize = sizeof(intptr_t);
auto pSrc = reinterpret_cast<intptr_t*>(src.data);
auto pDst = reinterpret_cast<char*>(dst.data);
for (auto i = 0; i < numElements; i++)
{
switch (intWidth)
{
case 8:
*pDst = static_cast<char>(pSrc[i]);
pDst++;
break;
case 16:
*(short*)pDst = static_cast<short>(pSrc[i]);
pDst += sizeof(short);
break;
case 32:
*(int*)pDst = static_cast<int>(pSrc[i]);
pDst += sizeof(int);
break;
case 64:
*(long*)pDst = static_cast<long>(pSrc[i]);
pDst += sizeof(long);
break;
}
}
}
...@@ -66,8 +66,7 @@ namespace ...@@ -66,8 +66,7 @@ namespace
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass, DialectLoweringPass& m_pass);
bool isMin);
/// Conversion from types in the nGraph dialect to the Standard dialect. /// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter class NGraphTypeConverter : public TypeConverter
...@@ -386,13 +385,13 @@ namespace ...@@ -386,13 +385,13 @@ namespace
REWRITER(NGArgMaxRedOp) REWRITER(NGArgMaxRedOp)
{ {
lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass, false); lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass);
return matchSuccess(); return matchSuccess();
} }
REWRITER(NGArgMinRedOp) REWRITER(NGArgMinRedOp)
{ {
lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass, true); lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass);
return matchSuccess(); return matchSuccess();
} }
...@@ -468,16 +467,18 @@ namespace ...@@ -468,16 +467,18 @@ namespace
#undef REWRITER #undef REWRITER
template <typename T> template <typename RedOp>
void lowerIndexReduction(Operation* op, void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands, ArrayRef<Value*> operands,
PatternRewriter& rewriter, PatternRewriter& rewriter,
DialectLoweringPass& m_pass, DialectLoweringPass& m_pass)
bool isMin)
{ {
T argmin = cast<T>(op); static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
auto loc = argmin.getLoc(); "Template parameter is not supported by lowerIndexReduction");
auto axesAttr = argmin.axes();
RedOp redOp = cast<RedOp>(op);
auto loc = redOp.getLoc();
auto axesAttr = redOp.axes();
NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis"); NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
Attribute axisAttr = *axesAttr.begin(); Attribute axisAttr = *axesAttr.begin();
...@@ -489,19 +490,8 @@ namespace ...@@ -489,19 +490,8 @@ namespace
// Retrieve/generate Values for operands and result. // Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc); ScopedContext scope(rewriter, loc);
Value* arg = operands[0]; Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
Value* finalResult = m_pass.buildOutputDefs(op, rewriter)[0];
Type type = argmin.getResult()->getType();
NGTensorType resultTy = type.cast<NGTensorType>();
// MLIR doesn't support Index to/from Integer type-conversion
// We have to store our result in an IndexType tensor and call-back to a type-conversion routine in nGraph
// TODO: Fix this once MLIR provides explicit cast operations.
Value* result = m_pass.createTempTensor(
rewriter.getMemRefType(resultTy.getShape(), rewriter.getIndexType()),
resultTy.getNumElements() *
sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
rewriter);
// Views // Views
MemRefView vRes(result), vArg(arg); MemRefView vRes(result), vArg(arg);
...@@ -512,73 +502,60 @@ namespace ...@@ -512,73 +502,60 @@ namespace
auto resUbs = vRes.getUbs(); auto resUbs = vRes.getUbs();
auto argLbs = vArg.getLbs(); auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs(); auto argUbs = vArg.getUbs();
Type resTy = result->getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{ {
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vRes.rank()); auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs); auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vRes.getSteps(); auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis); auto initVal = vArg.lb(axis);
// clang-format off
LoopNestBuilder(pivs, resLbs, resUbs, steps)( LoopNestBuilder(pivs, resLbs, resUbs, steps)(
// single stmt body [&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
[&] {
iRes(ivs) = initVal;
}
);
} }
// reduction loops // Generate loop nest that computes the actual index reduction.
{ {
auto allIVs = IndexHandle::makeIndexHandles(vArg.rank()); auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs); auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
SmallVector<IndexHandle,8> nonRedIVs;
auto steps = vArg.getSteps(); auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs;
Type resTy = result->getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(),
"Expected integer result type in index reduction");
// iterate over all argument dimensions // iterate over all argument dimensions
LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)( LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
[&] {
// build a list of non-reduction IVs // build a list of non-reduction IVs
for (auto i = 0; i < vArg.rank(); i++) for (auto i = 0; i < vArg.rank(); i++)
{ {
if (i != axis) if (i != axis)
nonRedIVs.push_back(allIVs[i]); nonRedIVs.push_back(allIVs[i]);
} }
// load current min index
ValueHandle currMinIndx = iRes(nonRedIVs); // Load current min index with integer data type and convert it to index data type.
ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));
// Build list of IVs including current min index.
auto tempIVs = allIVs; auto tempIVs = allIVs;
// build list of IVs including current min index tempIVs[axis] = currRedIdx;
tempIVs[axis] = currMinIndx;
iRes(nonRedIVs) = isMin ? edsc::intrinsics::select(iArg(allIVs) < iArg(tempIVs), allIVs[axis], currMinIndx) :
edsc::intrinsics::select(iArg(tempIVs) < iArg(allIVs), allIVs[axis], currMinIndx);
}
);
}
// Call-back to convert Index tensor to Integer tensor // Select the min/max value and cast it back to integer type before storing it.
auto callBackFunc = m_pass.getCallDecl("__mlir_convert_index_to_int", ValueHandle newRedIdx =
{finalResult->getType(), result->getType(), rewriter.getIndexType(), rewriter.getIndexType()}, std::is_same<RedOp, NGArgMinRedOp>()
{}, ? edsc::intrinsics::select(
rewriter); iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
SmallVector<mlir::Value*, 4> args = {finalResult, /* dst tensor */ iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
result, /* src tensor */ });
/* Num of Elements */ }
rewriter.create<mlir::ConstantIndexOp>(
rewriter.getUnknownLoc(),
resultTy.getNumElements()
),
/* Integer size used in final result*/
rewriter.create<mlir::ConstantIndexOp>(
rewriter.getUnknownLoc(),
resultTy.getElementType().cast<NGIntegerType>().getWidth()
)
};
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
rewriter.replaceOp(op, {finalResult}); rewriter.replaceOp(op, result);
} }
} }
namespace mlir namespace mlir
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment