Commit 955361bb authored by nmostafa

Fix rebase issues. Style-apply

parent 4ef010fc
@@ -24,8 +24,8 @@
 #include "ngraph/graph_util.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/op/add.hpp"
-#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/argmax.hpp"
+#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
 #include "ngraph/op/util/index_reduction.hpp"
@@ -287,13 +287,13 @@ mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Add)
     return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
 }
 
-template<>
+template <>
 mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
 {
     return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
 }
 
-template<>
+template <>
 mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
 {
     return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
@@ -332,7 +332,7 @@ void MLIRCompiler::create_return()
     m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
 }
 
-template<typename RedOp>
+template <typename RedOp>
 mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
 {
     auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
@@ -344,10 +344,8 @@ mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
     mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
     return m_builder
-        ->create<RedOp>(mlir::UnknownLoc::get(&m_context),
-                        get_mlir_type(ng_node),
-                        arg_val,
-                        red_axes_attr)
+        ->create<RedOp>(
+            mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
         .getResult();
 }
 
 // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
@@ -409,7 +407,7 @@ void MLIRCompiler::execute()
     if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
     {
         opt_level = std::stoi(opt_level_str);
-        NGRAPH_CHECK(opt_level >=0 && opt_level <= 3 , "Invalid optimization level");
+        NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
     }
     // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
     // don't run MLIR passes that were already run. We also pass a default transformer to run
......
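Side note on the pattern being reformatted above: every COMPILE_OP_DECL specialization forwards to a shared helper, create_binary_op for element-wise ops and create_index_reduction for ArgMin/ArgMax. A minimal, self-contained sketch of that per-op template-specialization dispatch is below; TinyCompiler and the op tag types are hypothetical stand-ins, not the actual MLIRCompiler interface.

// Sketch of per-op explicit-specialization dispatch (hypothetical names).
#include <iostream>
#include <string>

struct Add {};
struct ArgMin {};
struct ArgMax {};

class TinyCompiler
{
public:
    // Only the explicit specializations below are defined; instantiating the
    // primary template for an unsupported op fails at link time.
    template <typename Op>
    std::string compile_op();
};

template <>
std::string TinyCompiler::compile_op<Add>()
{
    return "create_binary_op<NGAddOp>";
}

template <>
std::string TinyCompiler::compile_op<ArgMin>()
{
    return "create_index_reduction<NGArgMinRedOp>";
}

template <>
std::string TinyCompiler::compile_op<ArgMax>()
{
    return "create_index_reduction<NGArgMaxRedOp>";
}

int main()
{
    TinyCompiler c;
    std::cout << c.compile_op<ArgMin>() << "\n"; // prints create_index_reduction<NGArgMinRedOp>
}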
@@ -108,7 +108,7 @@ namespace ngraph
             template <typename BinOp>
             mlir::Value* create_binary_op(const ngraph::Node* ng_node);
 
-            template<typename RedOp>
+            template <typename RedOp>
             mlir::Value* create_index_reduction(const ngraph::Node* ng_node);
 
             void create_return();
......
@@ -235,7 +235,7 @@ namespace mlir
                 return floatType.getIntOrFloatBitWidth();
             if (NGBoolType boolType = type.dyn_cast<NGBoolType>())
                 return boolType.getWidth();
-            NGRAPH_FAIL() << "Unknown type";
+            NGRAPH_CHECK(false, "Unknown type");
             return -1;
         }
 
         /// Get number of elements
......
@@ -14,21 +14,24 @@
 // limitations under the License.
 //*****************************************************************************
 
-#include <mlir/ExecutionEngine/MemRefUtils.h>
 #include <stdint.h>
 #include "ngraph/ngraph_visibility.hpp"
+#include <mlir/ExecutionEngine/MemRefUtils.h>
 
 /// Call back to copy Index tensor to Int tensor
 /// Can handle int tensors of bitwidth 8, 16, 32 and 64
 /// Index width is always intptr_t
-extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst, mlir::StaticFloatMemRef src, size_t numElements, size_t intWidth)
+extern "C" NGRAPH_API void __mlir_convert_index_to_int(mlir::StaticFloatMemRef dst,
+                                                       mlir::StaticFloatMemRef src,
+                                                       size_t numElements,
+                                                       size_t intWidth)
 {
     size_t indexSize = sizeof(intptr_t);
     auto pSrc = reinterpret_cast<intptr_t*>(src.data);
     auto pDst = reinterpret_cast<char*>(dst.data);
     for (auto i = 0; i < numElements; i++)
     {
-        switch(intWidth)
+        switch (intWidth)
         {
         case 8:
             *pDst = static_cast<char>(pSrc[i]);
......
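The __mlir_convert_index_to_int callback reformatted above copies an index tensor (intptr_t elements) into an integer tensor whose element width is 8, 16, 32 or 64 bits. Below is a standalone sketch of that narrowing copy on plain buffers; the function name and memcpy-based stores are illustrative assumptions, since the real callback operates on mlir::StaticFloatMemRef descriptors and only its 8-bit case is visible in this hunk.

// Standalone sketch (assumed behavior) of copying intptr_t indices into an
// integer buffer whose element width is 8, 16, 32 or 64 bits.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

template <typename T>
void store_as(char* dst, intptr_t v)
{
    T narrowed = static_cast<T>(v);
    std::memcpy(dst, &narrowed, sizeof narrowed);
}

void convert_index_to_int(void* dst, const intptr_t* src, size_t numElements, size_t intWidth)
{
    char* pDst = static_cast<char*>(dst);
    for (size_t i = 0; i < numElements; i++)
    {
        switch (intWidth)
        {
        case 8: store_as<int8_t>(pDst, src[i]); break;
        case 16: store_as<int16_t>(pDst, src[i]); break;
        case 32: store_as<int32_t>(pDst, src[i]); break;
        case 64: store_as<int64_t>(pDst, src[i]); break;
        default: assert(false && "unsupported integer width");
        }
        pDst += intWidth / 8; // advance by the destination element size
    }
}

int main()
{
    std::vector<intptr_t> indices = {1, 0, 2, 3};
    std::vector<int32_t> out(indices.size());
    convert_index_to_int(out.data(), indices.data(), indices.size(), 32);
    for (int32_t v : out)
        std::cout << v << ' '; // prints: 1 0 2 3
}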
@@ -46,8 +46,12 @@ namespace
 #include "op_lowerers.inc"
 
     // Helpers
-    template<typename RedOp>
-    void lowerIndexReduction(Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter, DialectLoweringPass& m_pass, bool isMin);
+    template <typename RedOp>
+    void lowerIndexReduction(Operation* op,
+                             ArrayRef<Value*> operands,
+                             PatternRewriter& rewriter,
+                             DialectLoweringPass& m_pass,
+                             bool isMin);
 
     /// Use Dialect Converson Framework
     class DialectLowerer : public DialectConversion
@@ -94,6 +98,7 @@ namespace
                             ArrayRef<Type> args,
                             ArrayRef<Type> output,
                             PatternRewriter& rewriter);
+
     private:
         void findOutputValues();
         void processFakeInstrs();
@@ -189,14 +194,17 @@ namespace
             else
             {
                 auto tensorType = origResult->getType().cast<NGTensorType>();
-                auto newResult = createTempTensor(m_dialectLowerer.convertType(tensorType), tensorType.getSizeInBytes(), rewriter);
+                auto newResult = createTempTensor(m_dialectLowerer.convertType(tensorType),
+                                                  tensorType.getSizeInBytes(),
+                                                  rewriter);
                 newResults.push_back(newResult);
             }
         }
         return newResults;
     }
 
-    Value* DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
+    Value*
+        DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
     {
         auto callBackFunc = getCallDecl("__mlir_allocate",
                                         {rewriter.getIndexType(), rewriter.getIndexType()},
@@ -206,8 +214,7 @@ namespace
                               insertMemMgrDef(&rewriter), /* pointer to mem manager */
                               rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
                                                                      size)}; /* size to allocate */
-        auto newTemp =
-            rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
+        auto newTemp = rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
                            .getResult(0);
         return newTemp;
     }
@@ -424,14 +431,18 @@ namespace
     REWRITER(NGReturnOp) { rewriter.replaceOpWithNewOp<ReturnOp>(op); }
 
 #undef REWRITER
 
-template<typename T>
-void lowerIndexReduction(Operation* op, ArrayRef<Value*> operands, PatternRewriter& rewriter, DialectLoweringPass& m_pass, bool isMin)
-{
+    template <typename T>
+    void lowerIndexReduction(Operation* op,
+                             ArrayRef<Value*> operands,
+                             PatternRewriter& rewriter,
+                             DialectLoweringPass& m_pass,
+                             bool isMin)
+    {
         T argmin = cast<T>(op);
         auto loc = argmin.getLoc();
         auto axesAttr = argmin.axes();
-        NGRAPH_CHECK(axesAttr.size() == 1 , "Index Reduction op should have one reduction axis");
+        NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
 
         Attribute axisAttr = *axesAttr.begin();
         unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();
@@ -450,10 +461,10 @@
         // We have to store our result in an IndexType tensor and call-back to a type-conversion routine in nGraph
         // TODO: Fix this once MLIR provides explicit cast operations.
         Value* result = m_pass.createTempTensor(
-            rewriter.getMemRefType(resultTy.getShape(),rewriter.getIndexType()),
-            resultTy.getNumElements() * sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
-            rewriter
-            );
+            rewriter.getMemRefType(resultTy.getShape(), rewriter.getIndexType()),
+            resultTy.getNumElements() *
+                sizeof(intptr_t), /* hacky way to get target-dependent size of IndexType */
+            rewriter);
 
         // Views
         MemRefView vRes(result), vArg(arg);
......
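For orientation, lowerIndexReduction (per the comments above) scans a single reduction axis, keeps the index of the running minimum or maximum depending on isMin, and writes that index into an IndexType temp tensor that is later converted back to the requested integer type by the callback shown earlier. A framework-free sketch of the computation those generated loops perform on a row-major buffer follows; the helper below is illustrative and is not the MLIR lowering itself.

// Plain-C++ sketch of argmin/argmax along one axis of a row-major tensor
// (illustrative; the real code emits MLIR loops over MemRef views).
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<size_t> index_reduce(const std::vector<int>& data,
                                 const std::vector<size_t>& shape,
                                 size_t axis,
                                 bool isMin)
{
    size_t outer = 1, inner = 1;
    for (size_t i = 0; i < axis; ++i)
        outer *= shape[i];
    for (size_t i = axis + 1; i < shape.size(); ++i)
        inner *= shape[i];

    std::vector<size_t> result;
    result.reserve(outer * inner);
    for (size_t o = 0; o < outer; ++o)
    {
        for (size_t in = 0; in < inner; ++in)
        {
            size_t best = 0;
            for (size_t r = 1; r < shape[axis]; ++r)
            {
                int cur = data[(o * shape[axis] + r) * inner + in];
                int bst = data[(o * shape[axis] + best) * inner + in];
                if (isMin ? cur < bst : cur > bst)
                    best = r; // first occurrence wins on ties
            }
            result.push_back(best);
        }
    }
    return result;
}

int main()
{
    // 2x3 matrix, reduce along axis 1 (per row).
    std::vector<int> m = {4, 1, 7, 2, 9, 0};
    for (size_t idx : index_reduce(m, {2, 3}, 1, /*isMin=*/true))
        std::cout << idx << ' '; // prints: 1 2
}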
@@ -19,8 +19,8 @@
 #include "ngraph/assertion.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/add.hpp"
-#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/argmax.hpp"
+#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
 #include "ngraph/op/get_output_element.hpp"
......
@@ -85,16 +85,16 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i32)
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::i32, shape);
-    copy_data(a, test::NDArray<int,3>({
-                 {{12,2,10,9},{3,5,0,8},{7,9,1,5}},
-                 {{7,2,4,10},{6,10,2,2},{12,1,1,1}},
-                 {{10,2,2,4},{1,5,5,1},{7,12,2,2}}
-                 }).get_vector());
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
     auto result = backend->create_tensor(element::i32, rshape);
 
     auto handle = backend->compile(f);
     handle->call_with_validate({result}, {a});
-    EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0,1}), read_vector<int>(result));
+    EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int>(result));
 }
 
 NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
@@ -108,19 +108,18 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::i32, shape);
-    copy_data(a, test::NDArray<int,3>({
-                 {{12,2,10,9},{3,5,0,8},{7,9,1,5}},
-                 {{7,2,4,10},{6,10,2,2},{12,1,1,1}},
-                 {{10,2,2,4},{1,5,5,1},{7,12,2,2}}
-                 }).get_vector());
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
     auto result = backend->create_tensor(element::i64, rshape);
 
     auto handle = backend->compile(f);
     handle->call_with_validate({result}, {a});
-    EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0,1}), read_vector<int64_t>(result));
+    EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int64_t>(result));
 }
 
-
 NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
 {
     Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
@@ -130,8 +129,10 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
     auto backend = runtime::Backend::create("${BACKEND_NAME}");
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::f32, shape);
-    copy_data(a,
-              test::NDArray<int, 4>({{{{3, 1, 1, 2, 105},
+    copy_data(
+        a,
+        test::NDArray<int, 4>(
+            {{{{3, 1, 1, 2, 105},
               {0, 3, 2, 1, 2},
               {2, 4, 2, 0, 1},
               {2, 5, 1, 1, 22},
@@ -141,11 +142,7 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
               {2, 10, 1, 3, 2},
               {3, 1, 0, 0, 6},
               {2, 0, 0, 0, 0}}},
-             {{{0, 2, 1, 1, 0},
-              {0, 0, 0, 0, 1},
-              {0, 0, 1, 0, 3},
-              {2, 0, 0, 3, 0},
-              {0, 0, 0, 0, 1}},
+             {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
              {{2, 1, 0, 0, 1},
              {0, 2, 0, 0, 0},
              {1, 1, 2, 0, 2},
@@ -292,11 +289,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i32)
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::i32, shape);
-    copy_data(a, test::NDArray<int,3>({
-                 {{12,2,10,9},{3,5,0,8},{7,9,1,5}},
-                 {{7,2,4,10},{6,10,2,2},{12,1,1,1}},
-                 {{10,2,2,4},{1,5,5,1},{7,12,2,2}}
-                 }).get_vector());
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
     auto result = backend->create_tensor(element::i32, rshape);
 
     auto handle = backend->compile(f);
@@ -315,11 +312,11 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::i32, shape);
-    copy_data(a, test::NDArray<int,3>({
-                 {{12,2,10,9},{3,5,0,8},{7,9,1,5}},
-                 {{7,2,4,10},{6,10,2,2},{12,1,1,1}},
-                 {{10,2,2,4},{1,5,5,1},{7,12,2,2}}
-                 }).get_vector());
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
     auto result = backend->create_tensor(element::i64, rshape);
 
     auto handle = backend->compile(f);
@@ -327,7 +324,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
     EXPECT_EQ((vector<int64_t>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int64_t>(result));
 }
 
-
 NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
 {
     Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
@@ -337,8 +333,10 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
     auto backend = runtime::Backend::create("${BACKEND_NAME}");
     // Create some tensors for input/output
     auto a = backend->create_tensor(element::f32, shape);
-    copy_data(a,
-              test::NDArray<int, 4>({{{{3, 1, 1, 2, 105},
+    copy_data(
+        a,
+        test::NDArray<int, 4>(
+            {{{{3, 1, 1, 2, 105},
               {0, 3, 2, 1, 2},
               {2, 4, 2, 0, 1},
               {2, 5, 1, 1, 22},
@@ -348,11 +346,7 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
               {2, 10, 1, 3, 2},
               {3, 1, 0, 0, 6},
               {2, 0, 0, 0, 0}}},
-             {{{0, 2, 1, 1, 0},
-              {0, 0, 0, 0, 1},
-              {0, 0, 1, 0, 3},
-              {2, 0, 0, 3, 0},
-              {0, 0, 0, 0, 1}},
+             {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
              {{2, 1, 0, 0, 1},
              {0, 2, 0, 0, 0},
              {1, 1, 2, 0, 2},
@@ -366,7 +360,6 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
                           read_vector<int64_t>(result));
 }
 
-
 NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
 {
     Shape shape{3, 4, 2}; // CHW ->(0,1,2)
......
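The three argmin/argmax 3-D tests reformatted above share the same 3x3x4 input. Their setup (shape and reduction axis) lies outside these hunks, but the expected index vectors are consistent with reducing along axis 1; the small standalone check below reproduces both expected vectors under that assumption (the axis is inferred, not shown in the diff).

// Sanity-check sketch: argmin/argmax of the 3x3x4 test input along axis 1
// (axis assumed) reproduces the expected vectors used in the tests above.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const size_t C = 3, H = 3, W = 4; // input shape {3, 3, 4}
    std::vector<int> a = {12, 2, 10, 9,  3, 5,  0, 8, 7,  9,  1, 5,
                          7,  2, 4,  10, 6, 10, 2, 2, 12, 1,  1, 1,
                          10, 2, 2,  4,  1, 5,  5, 1, 7,  12, 2, 2};

    std::vector<int> argmin, argmax;
    for (size_t c = 0; c < C; ++c)
    {
        for (size_t w = 0; w < W; ++w)
        {
            size_t imin = 0, imax = 0;
            for (size_t h = 1; h < H; ++h)
            {
                int cur = a[(c * H + h) * W + w];
                if (cur < a[(c * H + imin) * W + w])
                    imin = h; // strict compare keeps the first occurrence on ties
                if (cur > a[(c * H + imax) * W + w])
                    imax = h;
            }
            argmin.push_back(static_cast<int>(imin));
            argmax.push_back(static_cast<int>(imax));
        }
    }

    // Expected by the tests: argmin {1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}
    //                        argmax {0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}
    for (int v : argmin)
        std::cout << v << ' ';
    std::cout << '\n';
    for (int v : argmax)
        std::cout << v << ' ';
    std::cout << '\n';
}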