Unverified Commit e762203e authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3021 from NervanaSystems/dcaballe/argmin

[MLIR] Add ArgMin/ArgMax lowering support
parents 150250b0 a3768ee4
......@@ -24,8 +24,11 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/util/index_reduction.hpp"
#include "ngraph/type/element_type.hpp"
#include <llvm/ADT/STLExtras.h>
......@@ -110,12 +113,12 @@ void MLIRCompiler::build_ng_dialect_module()
for (auto input : kernel_inputs)
{
args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
args_type_list.push_back(get_mlir_type(input.get()));
}
for (auto output : kernel_outputs)
{
result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
result_type_list.push_back(get_mlir_type(output.get()));
}
auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);
......@@ -146,17 +149,23 @@ void MLIRCompiler::build_ng_dialect_module()
dump_mlir_module("nGraph Dialect Dump:");
}
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
// Converts nGraph shape \p ng_shape to MLIR shape \p mlir_shape.
static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
{
SmallVector<int64_t, 4> shape;
for (auto d : tensor->get_shape())
for (auto dim : ng_shape)
{
shape.push_back(d);
mlir_shape.push_back(dim);
}
}
return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
// element type.
mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
{
SmallVector<int64_t, 4> mlir_shape;
get_mlir_shape(tensor->get_shape(), mlir_shape);
return mlir::NGTensorType::get(
&m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
}
// Converts an nGraph element type into an MLIR type.
......@@ -195,6 +204,12 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
#endif
}
mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
{
descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
return get_mlir_type(out_tensor);
}
void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
{
NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),
......@@ -280,6 +295,17 @@ namespace ngraph
return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
}
template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
{
......@@ -316,6 +342,22 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
}
template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
auto arg = idx_red->get_argument(0);
size_t red_axis = idx_red->get_reduction_axis();
mlir::Value* arg_val = get_tensor_value(arg->get_output_tensor_ptr().get()).m_value;
mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
return m_builder
->create<RedOp>(
mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
.getResult();
}
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors
// helpers to be used inside the function.
void MLIRCompiler::bind_arguments()
......@@ -376,10 +418,17 @@ void MLIRCompiler::execute()
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
unsigned opt_level = 3;
if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
{
opt_level = std::stoi(opt_level_str);
NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
}
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run
// LLVM optimizations at level 3.
auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
auto llvm_transformer =
mlir::makeOptimizingTransformer(opt_level /*optLevel*/, 0 /*sizeLevel*/);
auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
m_engine = std::move(maybeEngine.get());
......
......@@ -91,6 +91,8 @@ namespace ngraph
mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
mlir::Type get_mlir_type(const element::Type& type);
mlir::Type get_mlir_type(const ngraph::Node* node);
TensorInfo get_tensor_value(descriptor::Tensor* tensor);
void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);
......@@ -106,6 +108,9 @@ namespace ngraph
template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node);
template <typename RedOp>
mlir::Value* create_index_reduction(const ngraph::Node* ng_node);
void create_return();
/// Helper to create memref arguments for MLIR function signature
......
......@@ -97,7 +97,7 @@ template <typename T>
static mlir::LogicalResult verifyIndexReductionOp(T* op)
{
// TODO: verifyAxisReductionOp(op) + return element type + single axis.
return mlir::failure();
return mlir::success();
}
template <typename T>
......
......@@ -160,6 +160,7 @@ namespace mlir
static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }
static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
size_t getWidth() { return 8; }
};
// Note that dialect types don't add new data members, so always possible
......@@ -222,6 +223,23 @@ namespace mlir
int getRank() { return getShape().size(); }
/// Computes tensor size in bytes
size_t getSizeInBytes()
{
return getNumElements() * llvm::divideCeil(getElementBitWidth(), 8);
}
size_t getElementBitWidth()
{
Type type = getElementType();
if (NGIntegerType intType = type.dyn_cast<NGIntegerType>())
return intType.getWidth();
if (NGFloatType floatType = type.dyn_cast<NGFloatType>())
return floatType.getIntOrFloatBitWidth();
if (NGBoolType boolType = type.dyn_cast<NGBoolType>())
return boolType.getWidth();
NGRAPH_CHECK(false, "Unknown type");
return -1;
}
/// Get number of elements
size_t getNumElements()
{
size_t s = 1;
auto shape = getShape();
......@@ -232,10 +250,8 @@ namespace mlir
return -1;
s *= shape[i];
}
// Multiply times element size
return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
return s;
}
/// Checks if two tensors are compatible. Compatible means:
/// Exactly same element types
/// Compatible shapes: see isCompatibleShape.
......
......@@ -37,7 +37,9 @@ namespace
{
using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::op;
using namespace ngraph::runtime;
using namespace ngraph::runtime::ngmlir;
class DialectLoweringPass;
......@@ -59,6 +61,13 @@ namespace
#include "op_lowerers.inc"
// Helpers
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass);
/// Conversion from types in the nGraph dialect to the Standard dialect.
class NGraphTypeConverter : public TypeConverter
{
......@@ -82,15 +91,17 @@ namespace
void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter);
mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args,
ArrayRef<Type> output,
PatternRewriter& rewriter);
private:
/// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
void findOutputValues();
void processFakeInstrs();
Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);
......@@ -136,8 +147,11 @@ namespace
void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
OwningRewritePatternList& patterns)
{
RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
patterns, &getContext(), *this);
RewriteListBuilder<NGAddOpConversion,
NGArgMaxRedOpConversion,
NGArgMinRedOpConversion,
NGDotOpConversion,
NGReturnOpConversion>::build(patterns, &getContext(), *this);
}
void DialectLoweringPass::findOutputValues()
......@@ -206,25 +220,30 @@ namespace
else
{
auto tensorType = origResult->getType().cast<NGTensorType>();
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{m_typeConverter.convertType(tensorType)},
rewriter);
auto size = tensorType.getSizeInBytes();
SmallVector<mlir::Value*, 4> args = {
insertMemMgrDef(&rewriter), /* pointer to mem manager */
rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
size)}; /* size to allocate */
auto newResult =
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
.getResult(0);
auto newResult = createTempTensor(
m_typeConverter.convertType(tensorType), tensorType.getSizeInBytes(), rewriter);
newResults.push_back(newResult);
}
}
return newResults;
}
Value*
DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
{
auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()},
{type},
rewriter);
SmallVector<mlir::Value*, 4> args = {
insertMemMgrDef(&rewriter), /* pointer to mem manager */
rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
size)}; /* size to allocate */
auto newTemp = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
.getResult(0);
return newTemp;
}
void DialectLoweringPass::processFakeInstrs()
{
auto context = getModule().getContext();
......@@ -326,7 +345,6 @@ namespace
// ADD
REWRITER(NGAddOp)
{
auto add = cast<NGAddOp>(op);
auto loc = add.getLoc();
......@@ -365,6 +383,18 @@ namespace
return matchSuccess();
}
REWRITER(NGArgMaxRedOp)
{
lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGArgMinRedOp)
{
lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass);
return matchSuccess();
}
REWRITER(NGDotOp)
{
auto dot = cast<NGDotOp>(op);
......@@ -412,7 +442,7 @@ namespace
IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);
// Constants, indexed values and indexes to be used inside the loop nest.
// Constants and indexed values to be used inside the loop nest.
IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));
......@@ -436,6 +466,96 @@ namespace
}
#undef REWRITER
template <typename RedOp>
void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass)
{
static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
"Template parameter is not supported by lowerIndexReduction");
RedOp redOp = cast<RedOp>(op);
auto loc = redOp.getLoc();
auto axesAttr = redOp.axes();
NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
Attribute axisAttr = *axesAttr.begin();
unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();
NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
"Expected one non-null operand in Index Reduction op");
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
// Views
MemRefView vRes(result), vArg(arg);
// Index Values
IndexedValue iRes(result), iArg(arg);
// Bounds Index Handles
auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs();
auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs();
Type resTy = result->getType().cast<MemRefType>().getElementType();
// Generate loop nest that initializes result to lower bound of the axis to be reduced.
{
auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis);
LoopNestBuilder(pivs, resLbs, resUbs, steps)(
[&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
}
// Generate loop nest that computes the actual index reduction.
{
auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
auto steps = vArg.getSteps();
SmallVector<IndexHandle, 8> nonRedIVs;
Type resTy = result->getType().cast<MemRefType>().getElementType();
NGRAPH_CHECK(resTy.isa<IntegerType>(),
"Expected integer result type in index reduction");
// iterate over all argument dimensions
LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
// build a list of non-reduction IVs
for (auto i = 0; i < vArg.rank(); i++)
{
if (i != axis)
nonRedIVs.push_back(allIVs[i]);
}
// Load current min index with integer data type and convert it to index data type.
ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
(ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));
// Build list of IVs including current min index.
auto tempIVs = allIVs;
tempIVs[axis] = currRedIdx;
// Select the min/max value and cast it back to integer type before storing it.
ValueHandle newRedIdx =
std::is_same<RedOp, NGArgMinRedOp>()
? edsc::intrinsics::select(
iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
: edsc::intrinsics::select(
iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
});
}
rewriter.replaceOp(op, result);
}
}
namespace mlir
......
......@@ -32,6 +32,8 @@
};
DECL_OP_CONV(NGAddOp)
DECL_OP_CONV(NGArgMaxRedOp)
DECL_OP_CONV(NGArgMinRedOp)
DECL_OP_CONV(NGDotOp)
DECL_OP_CONV(NGReturnOp)
......
......@@ -4,6 +4,8 @@
#endif
MLIR_OP(Add)
MLIR_OP(ArgMin)
MLIR_OP(ArgMax)
MLIR_OP(Dot)
// Add new supported ops here
......
......@@ -15,9 +15,12 @@
//*****************************************************************************
#include "mlir_subgraph_extraction.hpp"
#include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp"
......@@ -105,6 +108,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
return false;
}
}
if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
{
// TODO: Remove this when MLIR has float point cmp support
if (!node->input(0).get_element_type().is_integral())
return false;
}
return true;
}
......
......@@ -240,6 +240,10 @@ batch_norm_training_0eps_f32
argmin_trivial
argmax_trivial
argmin_trivial_in_i32
argmin_3D_i32
argmin_3D_i64
argmax_3D_i32
argmax_3D_i64
sum_large_1d_to_scalar
sum_stable_acc
one_hot_scalar_2_in_3
......
......@@ -55,6 +55,107 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_2D_i32)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i32)
{
Shape shape{3, 3, 4};
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMin>(A, 1, element::i32), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a,
test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
{
Shape shape{3, 3, 4};
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMin>(A, 1, element::i64), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a,
test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i64), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(
a,
test::NDArray<int, 4>(
{{{{3, 1, 1, 2, 105},
{0, 3, 2, 1, 2},
{2, 4, 2, 0, 1},
{2, 5, 1, 1, 22},
{5, 2, 1, 7, 5}},
{{3, 1, 2, 2, 1},
{1, 7, 3, 8, 1},
{2, 10, 1, 3, 2},
{3, 1, 0, 0, 6},
{2, 0, 0, 0, 0}}},
{{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
{{2, 1, 0, 0, 1},
{0, 2, 0, 0, 0},
{1, 1, 2, 0, 2},
{1, 1, 1, 0, 1},
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int64_t>{1, 0, 3, 2, 2, 1, 0, 2, 2, 1, 0, 0, 0, 1, 0, 2, 0, 3, 3, 1}),
read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
......@@ -158,6 +259,107 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_2D_i32)
{
Shape shape{4, 3};
Shape rshape{3};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{0, 3, 0}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i32)
{
Shape shape{3, 3, 4};
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i32), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a,
test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
.get_vector());
auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
{
Shape shape{3, 3, 4};
Shape rshape{3, 4};
auto A = make_shared<op::Parameter>(element::i32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i64), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape);
copy_data(a,
test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int64_t>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
{
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
Shape rshape{2, 2, 5};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto f = make_shared<Function>(make_shared<op::ArgMax>(A, 3, element::i64), ParameterVector{A});
auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape);
copy_data(
a,
test::NDArray<int, 4>(
{{{{3, 1, 1, 2, 105},
{0, 3, 2, 1, 2},
{2, 4, 2, 0, 1},
{2, 5, 1, 1, 22},
{5, 2, 1, 7, 5}},
{{3, 1, 2, 2, 1},
{1, 7, 3, 8, 1},
{2, 10, 1, 3, 2},
{3, 1, 0, 0, 6},
{2, 0, 0, 0, 0}}},
{{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
{{2, 1, 0, 0, 1},
{0, 2, 0, 0, 0},
{1, 1, 2, 0, 2},
{1, 1, 1, 0, 1},
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f);
handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int64_t>{4, 1, 1, 4, 3, 0, 3, 1, 4, 0, 1, 4, 4, 3, 4, 0, 1, 2, 0, 4}),
read_vector<int64_t>(result));
}
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
{
Shape shape{3, 4, 2}; // CHW ->(0,1,2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment