Commit 955361bb authored by nmostafa's avatar nmostafa

Fix rebase issues. Style-apply

parent 4ef010fc
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/util/index_reduction.hpp" #include "ngraph/op/util/index_reduction.hpp"
...@@ -287,13 +287,13 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add) ...@@ -287,13 +287,13 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
return compiler.create_binary_op<mlir::NGAddOp>(ng_node); return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
} }
template<> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
{ {
return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
} }
template<> template <>
mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax) mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
{ {
return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node); return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
...@@ -332,7 +332,7 @@ void MLIRCompiler::create_return() ...@@ -332,7 +332,7 @@ void MLIRCompiler::create_return()
m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list); m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
} }
template<typename RedOp> template <typename RedOp>
mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node) mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
{ {
auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node); auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
...@@ -344,10 +344,8 @@ mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node) ...@@ -344,10 +344,8 @@ mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis}); mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
return m_builder return m_builder
->create<RedOp>(mlir::UnknownLoc::get(&m_context), ->create<RedOp>(
get_mlir_type(ng_node), mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
arg_val,
red_axes_attr)
.getResult(); .getResult();
} }
// Binds MLIR function arguments to the proper values. This includes externally allocated tensors // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
...@@ -409,7 +407,7 @@ void MLIRCompiler::execute() ...@@ -409,7 +407,7 @@ void MLIRCompiler::execute()
if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL")) if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
{ {
opt_level = std::stoi(opt_level_str); opt_level = std::stoi(opt_level_str);
NGRAPH_CHECK(opt_level >=0 && opt_level <= 3 , "Invalid optimization level"); NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
} }
// Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
// don't run MLIR passes that were already run. We also pass a default transformer to run // don't run MLIR passes that were already run. We also pass a default transformer to run
......
...@@ -108,7 +108,7 @@ namespace ngraph ...@@ -108,7 +108,7 @@ namespace ngraph
template <typename BinOp> template <typename BinOp>
mlir::Value* create_binary_op(const ngraph::Node* ng_node); mlir::Value* create_binary_op(const ngraph::Node* ng_node);
template<typename RedOp> template <typename RedOp>
mlir::Value* create_index_reduction(const ngraph::Node* ng_node); mlir::Value* create_index_reduction(const ngraph::Node* ng_node);
void create_return(); void create_return();
......
...@@ -235,7 +235,7 @@ namespace mlir ...@@ -235,7 +235,7 @@ namespace mlir
return floatType.getIntOrFloatBitWidth(); return floatType.getIntOrFloatBitWidth();
if (NGBoolType boolType = type.dyn_cast<NGBoolType>()) if (NGBoolType boolType = type.dyn_cast<NGBoolType>())
return boolType.getWidth(); return boolType.getWidth();
NGRAPH_FAIL() << "Unknown type"; NGRAPH_CHECK(false, "Unknown type");
return -1; return -1;
} }
/// Get number of elements /// Get number of elements
......
...@@ -14,21 +14,24 @@ ...@@ -14,21 +14,24 @@
// limitations under the License. // limitations under the License.
//***************************************************************************** //*****************************************************************************
#include <mlir/ExecutionEngine/MemRefUtils.h>
#include <stdint.h> #include <stdint.h>
#include "ngraph/ngraph_visibility.hpp" #include "ngraph/ngraph_visibility.hpp"
#include <mlir/ExecutionEngine/MemRefUtils.h>
/// Call back to copy Index tensor to Int tensor /// Call back to copy Index tensor to Int tensor
/// Can handle int tensors of bitwidth 8, 16, 32 and 64 /// Can handle int tensors of bitwidth 8, 16, 32 and 64
/// Index width is always intptr_t /// Index width is always intptr_t
extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst, mlir::StaticFloatMemRef src, size_t numElements, size_t intWidth) extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst,
mlir::StaticFloatMemRef src,
size_t numElements,
size_t intWidth)
{ {
size_t indexSize = sizeof(intptr_t); size_t indexSize = sizeof(intptr_t);
auto pSrc = reinterpret_cast<intptr_t*>(src.data); auto pSrc = reinterpret_cast<intptr_t*>(src.data);
auto pDst = reinterpret_cast<char*>(dst.data); auto pDst = reinterpret_cast<char*>(dst.data);
for (auto i = 0; i < numElements; i++) for (auto i = 0; i < numElements; i++)
{ {
switch(intWidth) switch (intWidth)
{ {
case 8: case 8:
*pDst = static_cast<char>(pSrc[i]); *pDst = static_cast<char>(pSrc[i]);
......
...@@ -46,8 +46,12 @@ namespace ...@@ -46,8 +46,12 @@ namespace
#include "op_lowerers.inc" #include "op_lowerers.inc"
// Helpers // Helpers
template<typename RedOp> template <typename RedOp>
void lowerIndexReduction(Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter, DialectLoweringPass& m_pass, bool isMin); void lowerIndexReduction(Operation* op,
ArrayRef<Value*> operands,
PatternRewriter& rewriter,
DialectLoweringPass& m_pass,
bool isMin);
/// Use Dialect Converson Framework /// Use Dialect Converson Framework
class DialectLowerer : public DialectConversion class DialectLowerer : public DialectConversion
...@@ -89,11 +93,12 @@ namespace ...@@ -89,11 +93,12 @@ namespace
void runOnModule() override; void runOnModule() override;
SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter); SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter); Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter);
mlir::Function* getCallDecl(StringRef name, mlir::Function* getCallDecl(StringRef name,
ArrayRef<Type> args, ArrayRef<Type> args,
ArrayRef<Type> output, ArrayRef<Type> output,
PatternRewriter& rewriter); PatternRewriter& rewriter);
private: private:
void findOutputValues(); void findOutputValues();
void processFakeInstrs(); void processFakeInstrs();
...@@ -189,26 +194,28 @@ namespace ...@@ -189,26 +194,28 @@ namespace
else else
{ {
auto tensorType = origResult->getType().cast<NGTensorType>(); auto tensorType = origResult->getType().cast<NGTensorType>();
auto newResult = createTempTensor(m_dialectLowerer.convertType(tensorType), tensorType.getSizeInBytes(), rewriter); auto newResult = createTempTensor(m_dialectLowerer.convertType(tensorType),
tensorType.getSizeInBytes(),
rewriter);
newResults.push_back(newResult); newResults.push_back(newResult);
} }
} }
return newResults; return newResults;
} }
Value* DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter) Value*
DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
{ {
auto callBackFunc = getCallDecl("__mlir_allocate", auto callBackFunc = getCallDecl("__mlir_allocate",
{rewriter.getIndexType(), rewriter.getIndexType()}, {rewriter.getIndexType(), rewriter.getIndexType()},
{type}, {type},
rewriter); rewriter);
SmallVector<mlir::Value*, 4> args = { SmallVector<mlir::Value*, 4> args = {
insertMemMgrDef(&rewriter), /* pointer to mem manager */ insertMemMgrDef(&rewriter), /* pointer to mem manager */
rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(), rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
size)}; /* size to allocate */ size)}; /* size to allocate */
auto newTemp = auto newTemp = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args) .getResult(0);
.getResult(0);
return newTemp; return newTemp;
} }
...@@ -424,54 +431,58 @@ namespace ...@@ -424,54 +431,58 @@ namespace
REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); } REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
#undef REWRITER #undef REWRITER
template<typename T> template <typename T>
void lowerIndexReduction(Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter, DialectLoweringPass& m_pass, bool isMin) void lowerIndexReduction(Operation* op,
{ ArrayRef<Value*> operands,
T argmin = cast<T>(op); PatternRewriter& rewriter,
auto loc = argmin.getLoc(); DialectLoweringPass& m_pass,
auto axesAttr = argmin.axes(); bool isMin)
NGRAPH_CHECK(axesAttr.size() == 1 , "Index Reduction op should have one reduction axis");
Attribute axisAttr = *axesAttr.begin();
unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();
NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
"Expected one non-null operand in Index Reduction op");
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
Value* finalResult = m_pass.buildOutputDefs(op, rewriter)[0];
Type type = argmin.getResult()->getType();
NGTensorType resultTy = type.cast<NGTensorType>();
// MLIR doesn't support Index to/from Integer type-conversion
// We have to store our result in an IndexType tensor and call-back to a type-conversion routine in nGraph
// TODO: Fix this once MLIR provides explicit cast operations.
Value* result = m_pass.createTempTensor(
rewriter.getMemRefType(resultTy.getShape(),rewriter.getIndexType()),
resultTy.getNumElements() * sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
rewriter
);
// Views
MemRefView vRes(result), vArg(arg);
// Index Values
IndexedValue iRes(result), iArg(arg);
// Bounds Index Handles
auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs();
auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs();
{ {
// Loop induction vars T argmin = cast<T>(op);
auto ivs = IndexHandle::makeIndexHandles(vRes.rank()); auto loc = argmin.getLoc();
auto pivs = IndexHandle::makeIndexHandlePointers(ivs); auto axesAttr = argmin.axes();
// Steps
auto steps = vRes.getSteps(); NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
auto initVal = vArg.lb(axis); Attribute axisAttr = *axesAttr.begin();
// clang-format off unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();
NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
"Expected one non-null operand in Index Reduction op");
// Retrieve/generate Values for operands and result.
ScopedContext scope(rewriter, loc);
Value* arg = operands[0];
auto arg_type = arg->getType().cast<MemRefType>();
Value* finalResult = m_pass.buildOutputDefs(op, rewriter)[0];
Type type = argmin.getResult()->getType();
NGTensorType resultTy = type.cast<NGTensorType>();
// MLIR doesn't support Index to/from Integer type-conversion
// We have to store our result in an IndexType tensor and call-back to a type-conversion routine in nGraph
// TODO: Fix this once MLIR provides explicit cast operations.
Value* result = m_pass.createTempTensor(
rewriter.getMemRefType(resultTy.getShape(), rewriter.getIndexType()),
resultTy.getNumElements() *
sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
rewriter);
// Views
MemRefView vRes(result), vArg(arg);
// Index Values
IndexedValue iRes(result), iArg(arg);
// Bounds Index Handles
auto resLbs = vRes.getLbs();
auto resUbs = vRes.getUbs();
auto argLbs = vArg.getLbs();
auto argUbs = vArg.getUbs();
{
// Loop induction vars
auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
// Steps
auto steps = vRes.getSteps();
auto initVal = vArg.lb(axis);
// clang-format off
LoopNestBuilder(pivs, resLbs, resUbs, steps)( LoopNestBuilder(pivs, resLbs, resUbs, steps)(
// single stmt body // single stmt body
[&] { [&] {
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#include "ngraph/assertion.hpp" #include "ngraph/assertion.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/experimental/compiled_kernel.hpp" #include "ngraph/op/experimental/compiled_kernel.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
......
...@@ -85,16 +85,16 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i32) ...@@ -85,16 +85,16 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i32)
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape); auto a = backend->create_tensor(element::i32, shape);
copy_data(a, test::NDArray<int,3>({ copy_data(a,
{{12,2,10,9},{3,5,0,8},{7,9,1,5}}, test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7,2,4,10},{6,10,2,2},{12,1,1,1}}, {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10,2,2,4},{1,5,5,1},{7,12,2,2}} {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
}).get_vector()); .get_vector());
auto result = backend->create_tensor(element::i32, rshape); auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {a}); handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0,1}), read_vector<int>(result)); EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int>(result));
} }
NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64) NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
...@@ -108,19 +108,18 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64) ...@@ -108,19 +108,18 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape); auto a = backend->create_tensor(element::i32, shape);
copy_data(a, test::NDArray<int,3>({ copy_data(a,
{{12,2,10,9},{3,5,0,8},{7,9,1,5}}, test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7,2,4,10},{6,10,2,2},{12,1,1,1}}, {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10,2,2,4},{1,5,5,1},{7,12,2,2}} {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
}).get_vector()); .get_vector());
auto result = backend->create_tensor(element::i64, rshape); auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {a}); handle->call_with_validate({result}, {a});
EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0,1}), read_vector<int64_t>(result)); EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int64_t>(result));
} }
NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64) NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
{ {
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3) Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
...@@ -130,28 +129,26 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64) ...@@ -130,28 +129,26 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape); auto a = backend->create_tensor(element::f32, shape);
copy_data(a, copy_data(
test::NDArray<int, 4>({{{{3, 1, 1, 2, 105}, a,
{0, 3, 2, 1, 2}, test::NDArray<int, 4>(
{2, 4, 2, 0, 1}, {{{{3, 1, 1, 2, 105},
{2, 5, 1, 1, 22}, {0, 3, 2, 1, 2},
{5, 2, 1, 7, 5}}, {2, 4, 2, 0, 1},
{{3, 1, 2, 2, 1}, {2, 5, 1, 1, 22},
{1, 7, 3, 8, 1}, {5, 2, 1, 7, 5}},
{2, 10, 1, 3, 2}, {{3, 1, 2, 2, 1},
{3, 1, 0, 0, 6}, {1, 7, 3, 8, 1},
{2, 0, 0, 0, 0}}}, {2, 10, 1, 3, 2},
{{{0, 2, 1, 1, 0}, {3, 1, 0, 0, 6},
{0, 0, 0, 0, 1}, {2, 0, 0, 0, 0}}},
{0, 0, 1, 0, 3}, {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
{2, 0, 0, 3, 0}, {{2, 1, 0, 0, 1},
{0, 0, 0, 0, 1}}, {0, 2, 0, 0, 0},
{{2, 1, 0, 0, 1}, {1, 1, 2, 0, 2},
{0, 2, 0, 0, 0}, {1, 1, 1, 0, 1},
{1, 1, 2, 0, 2}, {1, 0, 0, 0, 2}}}})
{1, 1, 1, 0, 1}, .get_vector());
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape); auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {a}); handle->call_with_validate({result}, {a});
...@@ -292,11 +289,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i32) ...@@ -292,11 +289,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i32)
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape); auto a = backend->create_tensor(element::i32, shape);
copy_data(a, test::NDArray<int,3>({ copy_data(a,
{{12,2,10,9},{3,5,0,8},{7,9,1,5}}, test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7,2,4,10},{6,10,2,2},{12,1,1,1}}, {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10,2,2,4},{1,5,5,1},{7,12,2,2}} {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
}).get_vector()); .get_vector());
auto result = backend->create_tensor(element::i32, rshape); auto result = backend->create_tensor(element::i32, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
...@@ -315,11 +312,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64) ...@@ -315,11 +312,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::i32, shape); auto a = backend->create_tensor(element::i32, shape);
copy_data(a, test::NDArray<int,3>({ copy_data(a,
{{12,2,10,9},{3,5,0,8},{7,9,1,5}}, test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
{{7,2,4,10},{6,10,2,2},{12,1,1,1}}, {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
{{10,2,2,4},{1,5,5,1},{7,12,2,2}} {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
}).get_vector()); .get_vector());
auto result = backend->create_tensor(element::i64, rshape); auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
...@@ -327,7 +324,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64) ...@@ -327,7 +324,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
EXPECT_EQ((vector<int64_t>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int64_t>(result)); EXPECT_EQ((vector<int64_t>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int64_t>(result));
} }
NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64) NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
{ {
Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3) Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
...@@ -337,28 +333,26 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64) ...@@ -337,28 +333,26 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
auto backend = runtime::Backend::create("${BACKEND_NAME}"); auto backend = runtime::Backend::create("${BACKEND_NAME}");
// Create some tensors for input/output // Create some tensors for input/output
auto a = backend->create_tensor(element::f32, shape); auto a = backend->create_tensor(element::f32, shape);
copy_data(a, copy_data(
test::NDArray<int, 4>({{{{3, 1, 1, 2, 105}, a,
{0, 3, 2, 1, 2}, test::NDArray<int, 4>(
{2, 4, 2, 0, 1}, {{{{3, 1, 1, 2, 105},
{2, 5, 1, 1, 22}, {0, 3, 2, 1, 2},
{5, 2, 1, 7, 5}}, {2, 4, 2, 0, 1},
{{3, 1, 2, 2, 1}, {2, 5, 1, 1, 22},
{1, 7, 3, 8, 1}, {5, 2, 1, 7, 5}},
{2, 10, 1, 3, 2}, {{3, 1, 2, 2, 1},
{3, 1, 0, 0, 6}, {1, 7, 3, 8, 1},
{2, 0, 0, 0, 0}}}, {2, 10, 1, 3, 2},
{{{0, 2, 1, 1, 0}, {3, 1, 0, 0, 6},
{0, 0, 0, 0, 1}, {2, 0, 0, 0, 0}}},
{0, 0, 1, 0, 3}, {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
{2, 0, 0, 3, 0}, {{2, 1, 0, 0, 1},
{0, 0, 0, 0, 1}}, {0, 2, 0, 0, 0},
{{2, 1, 0, 0, 1}, {1, 1, 2, 0, 2},
{0, 2, 0, 0, 0}, {1, 1, 1, 0, 1},
{1, 1, 2, 0, 2}, {1, 0, 0, 0, 2}}}})
{1, 1, 1, 0, 1}, .get_vector());
{1, 0, 0, 0, 2}}}})
.get_vector());
auto result = backend->create_tensor(element::i64, rshape); auto result = backend->create_tensor(element::i64, rshape);
auto handle = backend->compile(f); auto handle = backend->compile(f);
handle->call_with_validate({result}, {a}); handle->call_with_validate({result}, {a});
...@@ -366,7 +360,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64) ...@@ -366,7 +360,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
read_vector<int64_t>(result)); read_vector<int64_t>(result));
} }
NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
{ {
Shape shape{3, 4, 2}; // CHW ->(0,1,2) Shape shape{3, 4, 2}; // CHW ->(0,1,2)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment