Commit 6e672209 authored by nmostafa

Add support for any rank ArgMin. Hacky workaround for IndexType tensors by doing type-conversion after operation

parent 64b43082
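In outline, the change lowers ArgMin into loop nests that write into a pointer-sized IndexType temporary, then calls back into nGraph to narrow the result to the requested integer element type. A minimal standalone sketch of that compute-then-convert pattern (hypothetical helper, not code from this commit):

#include <cstdint>
#include <vector>

// Compute argmin in pointer-sized index arithmetic, then narrow the result
// to the tensor's integer element type as a separate, final step.
std::vector<int32_t> argminThenNarrow(const std::vector<float>& v)
{
    intptr_t best = 0; // index-typed intermediate, like the IndexType temp tensor
    for (intptr_t i = 1; i < static_cast<intptr_t>(v.size()); i++)
        if (v[i] < v[best])
            best = i;
    // Type-conversion after the operation, mirroring __mlir_convert_index_to_int.
    return {static_cast<int32_t>(best)};
}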
@@ -21,6 +21,7 @@ set(SRC
     compiler.cpp
     lowerer.cpp
     memory_manager.cpp
+    helpers.cpp
     pass/mlir_subgraph_extraction.cpp
     pass/mlir_subgraph_extraction.hpp
 )
...
@@ -204,14 +204,6 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
 mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
 {
     descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
-    if (TI(*node) == TI(ngraph::op::ArgMin))
-    {
-        SmallVector<int64_t, 4> mlir_shape;
-        get_mlir_shape(out_tensor->get_shape(), mlir_shape);
-        return mlir::NGTensorType::get(&m_context, mlir::IndexType::get(&m_context), mlir_shape);
-    }
     return get_mlir_type(out_tensor);
 }
@@ -404,8 +396,8 @@ void MLIRCompiler::execute()
     // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
     // don't run MLIR passes that were already run. We also pass a default transformer to run
     // LLVM optimizations at level 3.
-    auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
-    auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
+    //auto llvm_transformer = mlir::makeOptimizingTransformer(0 /*optLevel*/, 0 /*sizeLevel*/);
+    auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), nullptr);
     NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
     m_engine = std::move(maybeEngine.get());
...
@@ -51,6 +51,7 @@ namespace mlir
     // reuse std float types as-is
     using NGFloatType = mlir::FloatType;
+    using NGIndexType = mlir::IndexType;

     /// Integer type. It represents an integer of width 8,16,32,64. Signed or not.
     class NGIntegerType : public mlir::Type::TypeBase<NGIntegerType, mlir::Type>
@@ -160,6 +161,7 @@ namespace mlir
         static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
         static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
+        size_t getWidth() { return 8; }
     };

     // Note that dialect types don't add new data members, so always possible
@@ -222,6 +224,25 @@ namespace mlir
         int getRank() { return getShape().size(); }

         /// Computes tensor size in bytes
         size_t getSizeInBytes()
+        {
+            return getNumElements() * llvm::divideCeil(getElementBitWidth(), 8);
+        }
+
+        /// Computes the element bit width
+        size_t getElementBitWidth()
+        {
+            Type type = getElementType();
+            if (NGIntegerType intType = type.dyn_cast<NGIntegerType>())
+                return intType.getWidth();
+            if (NGFloatType floatType = type.dyn_cast<NGFloatType>())
+                return floatType.getIntOrFloatBitWidth();
+            if (NGIndexType indexType = type.dyn_cast<NGIndexType>())
+                return sizeof(intptr_t) * 8; // index is pointer-sized; width in bits
+            if (NGBoolType boolType = type.dyn_cast<NGBoolType>())
+                return boolType.getWidth();
+            NGRAPH_FAIL() << "Unknown type";
+            return -1;
+        }
+
+        /// Get number of elements
+        size_t getNumElements()
         {
             size_t s = 1;
             auto shape = getShape();
@@ -232,10 +253,8 @@ namespace mlir
                     return -1;
                 s *= shape[i];
             }
-            // Multiply times element size
-            return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+            return s;
         }

         /// Checks if two tensors are compatible. Compatible means:
         /// Exactly same element types
         /// Compatible shapes: see isCompatibleShape.
...
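The size computation thus reduces to numElements * ceil(bitWidth / 8). A standalone sketch under that formula (hypothetical helper name; static shapes only):

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of getSizeInBytes() above: bytes = numElements * ceil(bits / 8).
size_t tensorSizeInBytes(const std::vector<int64_t>& shape, size_t elementBitWidth)
{
    size_t numElements = 1;
    for (auto dim : shape)
        numElements *= static_cast<size_t>(dim); // dynamic dims (-1) not handled here
    return numElements * ((elementBitWidth + 7) / 8); // same rounding as llvm::divideCeil
}

// Example: a 2x3 tensor of i32 -> 6 elements * 4 bytes = 24 bytes,
// i.e. tensorSizeInBytes({2, 3}, 32) == 24.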
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <stdint.h>

#include "ngraph/ngraph_visibility.hpp"

#include <mlir/ExecutionEngine/MemRefUtils.h>

/// Callback to copy an Index tensor to an Int tensor.
/// Can handle int tensors of bit width 8, 16, 32 and 64.
/// Index width is always intptr_t.
extern "C" NGRAPH_API void __mlir_convert_index_to_int(
    mlir::StaticFloatMemRef dst, mlir::StaticFloatMemRef src, size_t numElements, size_t intWidth)
{
    auto pSrc = reinterpret_cast<intptr_t*>(src.data);
    auto pDst = reinterpret_cast<char*>(dst.data);
    for (size_t i = 0; i < numElements; i++)
    {
        switch (intWidth)
        {
        case 8:
            *pDst = static_cast<char>(pSrc[i]);
            pDst++;
            break;
        case 16:
            *(short*)pDst = static_cast<short>(pSrc[i]);
            pDst += sizeof(short);
            break;
        case 32:
            *(int*)pDst = static_cast<int>(pSrc[i]);
            pDst += sizeof(int);
            break;
        case 64:
            *(long*)pDst = static_cast<long>(pSrc[i]);
            pDst += sizeof(long);
            break;
        }
    }
}
\ No newline at end of file
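A minimal driver showing what the callback does; it relies only on the public data member that the callback itself uses, and is illustrative rather than part of the commit:

#include <cassert>
#include <cstdint>

int main()
{
    intptr_t indices[4] = {3, 1, 4, 1}; // index-typed argmin results
    int32_t out[4] = {0};
    mlir::StaticFloatMemRef src, dst;   // reused as raw pointer holders, as above
    src.data = reinterpret_cast<float*>(indices);
    dst.data = reinterpret_cast<float*>(out);
    __mlir_convert_index_to_int(dst, src, 4 /*numElements*/, 32 /*intWidth*/);
    assert(out[0] == 3 && out[2] == 4); // each intptr_t narrowed to an int32
    return 0;
}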
@@ -83,12 +83,13 @@ namespace
     }
     void runOnModule() override;
     SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
+    Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter);

-private:
     mlir::Function* getCallDecl(StringRef name,
                                 ArrayRef<Type> args,
                                 ArrayRef<Type> output,
                                 PatternRewriter& rewriter);
+private:
     void findOutputValues();
     void processFakeInstrs();
     Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
@@ -183,23 +184,27 @@ namespace
         else
         {
             auto tensorType = origResult->getType().cast<NGTensorType>();
-            auto callBackFunc = getCallDecl("__mlir_allocate",
-                                            {rewriter.getIndexType(), rewriter.getIndexType()},
-                                            {m_dialectLowerer.convertType(tensorType)},
-                                            rewriter);
-            auto size = tensorType.getSizeInBytes();
-            SmallVector<mlir::Value*, 4> args = {
-                insertMemMgrDef(&rewriter), /* pointer to mem manager */
-                rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
-                                                       size)}; /* size to allocate */
-            auto newResult =
-                rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
-                    .getResult(0);
+            auto newResult = createTempTensor(m_dialectLowerer.convertType(tensorType),
+                                              tensorType.getSizeInBytes(),
+                                              rewriter);
             newResults.push_back(newResult);
         }
     }
     return newResults;
 }

+Value* DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
+{
+    auto callBackFunc = getCallDecl("__mlir_allocate",
+                                    {rewriter.getIndexType(), rewriter.getIndexType()},
+                                    {type},
+                                    rewriter);
+    SmallVector<mlir::Value*, 4> args = {
+        insertMemMgrDef(&rewriter), /* pointer to mem manager */
+        rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
+                                               size)}; /* size to allocate */
+    auto newTemp =
+        rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
+            .getResult(0);
+    return newTemp;
+}
+
 void DialectLoweringPass::processFakeInstrs()
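The temporaries above are serviced by the __mlir_allocate callback, which receives a memory-manager pointer and a byte count. The manager itself is outside this diff; a bump-style sketch of what the runtime side might look like (names hypothetical, not nGraph's actual MLIRMemMgr):

#include <cstdlib>
#include <vector>

// Hypothetical stand-in for the memory manager behind __mlir_allocate.
struct TempBufferPool
{
    std::vector<void*> buffers;
    ~TempBufferPool()
    {
        for (void* p : buffers)
            std::free(p); // temporaries live until the compiled kernel finishes
    }
};

// Shape of the callback the lowered CallOp would bind to (sketch only).
extern "C" void* __mlir_allocate_sketch(TempBufferPool* pool, size_t sizeInBytes)
{
    void* buf = std::malloc(sizeInBytes);
    pool->buffers.push_back(buf);
    return buf;
}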
@@ -406,6 +411,10 @@ namespace
     {
         auto argmin = cast<NGArgMinRedOp>(op);
         auto loc = argmin.getLoc();
+        auto axesAttr = argmin.axes();
+        NGRAPH_ASSERT(axesAttr.size() == 1) << "ArgMin should have one reduction axis";
+        unsigned axis = axesAttr.begin()->dyn_cast<IntegerAttr>().getInt();
+
         NGRAPH_ASSERT(operands.size() == 1 && operands[0] != nullptr)
             << "Expected one non-null operand in ArgMin op";
@@ -414,23 +423,94 @@ namespace
         ScopedContext scope(rewriter, loc);
         Value* arg = operands[0];
         auto arg_type = arg->getType().cast<MemRefType>();
-        NGRAPH_ASSERT(arg_type.getRank() == 2) << "Unsupported tensor type in ArgMin op";
-
-        //axis = op->getAttr();
-        //NGRAPH_ASSERT(axis == 0) << "Unsupported axis in ArgMin op";
-        Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
-        //NGRAPH_ASSERT(lhs && rhs && result) << "Unexpected null values in MatmulBiasOp";
-
-        // Create the following loop nest for argmin operation:
-        //   for(i, I, 1)
-        //     for(j, J, 1) // Reduction dimension
-        //       res[j] = select((arg[i, j] < res[j]), i, res[j])
+        Value* finalResult = m_pass.buildOutputDefs(op, rewriter)[0];
+        auto resultTy = argmin.getResult()->getType().cast<NGTensorType>();
+        // MLIR doesn't support Index to/from Integer type-conversion.
+        // We have to store our result in an IndexType tensor and call back to a
+        // type-conversion routine in nGraph.
+        // TODO: Fix this once MLIR provides explicit cast operations.
+        Value* result = m_pass.createTempTensor(
+            rewriter.getMemRefType(resultTy.getShape(), rewriter.getIndexType()),
+            resultTy.getSizeInBytes(),
+            rewriter);
+
+        // Views
+        MemRefView vRes(result), vArg(arg);
+        // Index Values
+        IndexedValue iRes(result), iArg(arg);
+        // Bounds Index Handles
+        auto resLbs = vRes.getLbs();
+        auto resUbs = vRes.getUbs();
+        auto argLbs = vArg.getLbs();
+        auto argUbs = vArg.getUbs();
+
+        // Initialization loops: seed every result element with the reduction
+        // axis' lower bound.
+        {
+            // Loop induction vars
+            auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
+            auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+            // Steps
+            auto steps = vRes.getSteps();
+            auto initVal = vArg.lb(axis);
+            // clang-format off
+            LoopNestBuilder(pivs, resLbs, resUbs, steps)(
+                // single stmt body
+                [&] {
+                    iRes(ivs) = initVal;
+                });
+            // clang-format on
+        }
+
+        // Reduction loops
+        {
+            auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
+            auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
+            SmallVector<IndexHandle, 8> nonRedIVs;
+            auto steps = vArg.getSteps();
+
+            // Iterate over all argument dimensions
+            LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)(
+                [&] {
+                    // Build a list of non-reduction IVs
+                    for (unsigned i = 0; i < vArg.rank(); i++)
+                    {
+                        if (i != axis)
+                            nonRedIVs.push_back(allIVs[i]);
+                    }
+                    // Load the current min index
+                    ValueHandle currMinIndx = iRes(nonRedIVs);
+                    // Build a list of IVs with the current min index on the
+                    // reduction axis
+                    auto tempIVs = allIVs;
+                    tempIVs[axis] = currMinIndx;
+                    iRes(nonRedIVs) = edsc::intrinsics::select(
+                        iArg(allIVs) < iArg(tempIVs), allIVs[axis], currMinIndx);
+                });
+        }
+
+        // Call back to convert the Index tensor to an Integer tensor
+        auto callBackFunc = m_pass.getCallDecl(
+            "__mlir_convert_index_to_int",
+            {finalResult->getType(), result->getType(), rewriter.getIndexType(), rewriter.getIndexType()},
+            {},
+            rewriter);
+        SmallVector<mlir::Value*, 4> args = {
+            finalResult, /* dst tensor */
+            result,      /* src tensor */
+            /* number of elements */
+            rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
+                                                   resultTy.getNumElements()),
+            /* integer width used */
+            rewriter.create<mlir::ConstantIndexOp>(
+                rewriter.getUnknownLoc(),
+                resultTy.getElementType().cast<NGIntegerType>().getWidth())};
+        rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args);
+
+        rewriter.replaceOp(op, {finalResult});
+#if 0
         MemRefView v_res(result), v_arg(arg);

         unsigned n_dim = v_arg.fastestVarying() - 1;
@@ -459,8 +539,8 @@ namespace
             i_res(m) = edsc::intrinsics::select(i_arg(n, m) < i_arg(curr_res, m), n, curr_res);
         });
         });
-        rewriter.replaceOp(op, {result});
+#endif
     }

     REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
...
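The generated nest computes, for every non-reduction coordinate, the index of the minimum along the chosen axis. A plain C++ reference of the same semantics for the 2-D, axis-0 case (illustrative only, not code from the commit):

#include <cstddef>
#include <vector>

// res[j] = index i minimizing arg[i][j]; mirrors the init loop (res = lower
// bound of the axis) followed by the select-based reduction loop above.
std::vector<size_t> argminAxis0(const std::vector<std::vector<float>>& arg)
{
    size_t rows = arg.size(), cols = arg[0].size();
    std::vector<size_t> res(cols, 0); // initialization loop
    for (size_t i = 0; i < rows; i++)     // reduction axis IV
        for (size_t j = 0; j < cols; j++) // non-reduction IV
        {
            size_t curr = res[j];
            // same select as the EDSC body: keep i if arg[i][j] < arg[curr][j]
            res[j] = (arg[i][j] < arg[curr][j]) ? i : curr;
        }
    return res;
}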