Commit e762203e (Unverified), authored Jul 10, 2019 by Scott Cyphers; committed by GitHub on Jul 10, 2019.
Merge pull request #3021 from NervanaSystems/dcaballe/argmin

[MLIR] Add ArgMin/ArgMax lowering support

Parents: 150250b0, a3768ee4
Showing 10 changed files with 445 additions and 35 deletions:

  src/contrib/mlir/compiler.cpp                        +59  -10
  src/contrib/mlir/compiler.hpp                        +5   -0
  src/contrib/mlir/dialect/ops.cpp                     +1   -1
  src/contrib/mlir/dialect/type.hpp                    +19  -3
  src/contrib/mlir/lowerer.cpp                         +141 -21
  src/contrib/mlir/op_lowerers.inc                     +2   -0
  src/contrib/mlir/ops_supported.inc                   +2   -0
  src/contrib/mlir/pass/mlir_subgraph_extraction.cpp   +10  -0
  src/ngraph/runtime/plaidml/unit_test.manifest        +4   -0
  test/backend_arg_reduce.in.cpp                       +202 -0
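For orientation, here is a standalone C++ reference sketch (not part of the diff) of the computation the new lowering implements: ArgMin along one axis of a row-major tensor returns, for each output coordinate, the index of the smallest element along that axis. Run on the 3x3x4 input used by the new argmin_3D_i32 test, it reproduces the expected output {1,0,1,2, 1,2,2,2, 1,0,0,1}. The helper name `argmin_ref` is hypothetical.

```cpp
// Standalone reference sketch: ArgMin along `axis` of a dense row-major tensor,
// mirroring the init + reduce loop nests that the new lowerIndexReduction emits.
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> argmin_ref(const std::vector<int>& data,
                                const std::vector<int64_t>& shape,
                                size_t axis)
{
    // Collapse the tensor into (outer, reduced, inner) extents around `axis`.
    int64_t outer = 1, inner = 1;
    for (size_t i = 0; i < axis; ++i)
        outer *= shape[i];
    for (size_t i = axis + 1; i < shape.size(); ++i)
        inner *= shape[i];
    int64_t red = shape[axis];

    // Loop nest 1: initialize every result slot to the reduced axis' lower bound.
    std::vector<int64_t> result(outer * inner, 0);

    // Loop nest 2: keep the index of the smallest element seen along `axis`.
    for (int64_t o = 0; o < outer; ++o)
        for (int64_t r = 0; r < red; ++r)
            for (int64_t i = 0; i < inner; ++i)
            {
                int64_t best = result[o * inner + i];
                if (data[(o * red + r) * inner + i] < data[(o * red + best) * inner + i])
                    result[o * inner + i] = r;
            }
    return result;
}

int main()
{
    // Same 3x3x4 input as the new argmin_3D_i32 test; reducing axis 1 prints
    // 1 0 1 2 1 2 2 2 1 0 0 1, matching the test's expected vector.
    std::vector<int> a = {12, 2, 10, 9, 3, 5, 0, 8, 7, 9,  1, 5,
                          7,  2, 4,  10, 6, 10, 2, 2, 12, 1, 1, 1,
                          10, 2, 2,  4,  1, 5,  5, 1, 7,  12, 2, 2};
    for (int64_t idx : argmin_ref(a, {3, 3, 4}, 1))
        std::cout << idx << ' ';
    std::cout << '\n';
}
```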
src/contrib/mlir/compiler.cpp

@@ -24,8 +24,11 @@
 #include "ngraph/graph_util.hpp"
 #include "ngraph/node.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/argmax.hpp"
+#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
+#include "ngraph/op/util/index_reduction.hpp"
 #include "ngraph/type/element_type.hpp"
 #include <llvm/ADT/STLExtras.h>

@@ -110,12 +113,12 @@ void MLIRCompiler::build_ng_dialect_module()
     for (auto input : kernel_inputs)
     {
-        args_type_list.push_back(get_mlir_type(input->get_output_tensor_ptr().get()));
+        args_type_list.push_back(get_mlir_type(input.get()));
     }

     for (auto output : kernel_outputs)
     {
-        result_type_list.push_back(get_mlir_type(output->get_output_tensor_ptr().get()));
+        result_type_list.push_back(get_mlir_type(output.get()));
     }

     auto func_type = mlir::FunctionType::get(args_type_list, result_type_list, &m_context);

@@ -146,17 +149,23 @@ void MLIRCompiler::build_ng_dialect_module()
     dump_mlir_module("nGraph Dialect Dump:");
 }

-// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
-// element type.
+// Converts nGraph shape \p ng_shape to MLIR shape \p mlir_shape.
+static void get_mlir_shape(ngraph::Shape ng_shape, llvm::SmallVectorImpl<int64_t>& mlir_shape)
+{
+    for (auto dim : ng_shape)
+    {
+        mlir_shape.push_back(dim);
+    }
+}
+
+// Converts an nGraph Tensor into an MLIR tensor type, including the conversion of the Tensor's
+// element type.
 mlir::Type MLIRCompiler::get_mlir_type(const descriptor::Tensor* tensor)
 {
-    SmallVector<int64_t, 4> shape;
-    for (auto d : tensor->get_shape())
-    {
-        shape.push_back(d);
-    }
-    return mlir::NGTensorType::get(&m_context, get_mlir_type(tensor->get_element_type()), shape);
+    SmallVector<int64_t, 4> mlir_shape;
+    get_mlir_shape(tensor->get_shape(), mlir_shape);
+    return mlir::NGTensorType::get(
+        &m_context, get_mlir_type(tensor->get_element_type()), mlir_shape);
 }

 // Converts an nGraph element type into an MLIR type.

@@ -195,6 +204,12 @@ mlir::Type MLIRCompiler::get_mlir_type(const element::Type& type)
 #endif
 }

+mlir::Type MLIRCompiler::get_mlir_type(const ngraph::Node* node)
+{
+    descriptor::Tensor* out_tensor = node->get_output_tensor_ptr().get();
+    return get_mlir_type(out_tensor);
+}
+
 void MLIRCompiler::update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value)
 {
     NGRAPH_CHECK(m_tensor_to_value_map.find(tensor) == m_tensor_to_value_map.end(),

@@ -280,6 +295,17 @@ namespace ngraph
         return compiler.create_binary_op<mlir::NGAddOp>(ng_node);
     }

+    template <>
+    mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMax)
+    {
+        return compiler.create_index_reduction<mlir::NGArgMaxRedOp>(ng_node);
+    }
+
+    template <>
+    mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::ArgMin)
+    {
+        return compiler.create_index_reduction<mlir::NGArgMinRedOp>(ng_node);
+    }
+
     template <>
     mlir::Value* MLIRCompiler::COMPILE_OP_DECL(ngraph::op::Dot)
     {

@@ -316,6 +342,22 @@ void MLIRCompiler::create_return()
     m_builder->create<mlir::NGReturnOp>(mlir::UnknownLoc::get(&m_context), value_list);
 }

+template <typename RedOp>
+mlir::Value* MLIRCompiler::create_index_reduction(const ngraph::Node* ng_node)
+{
+    auto* idx_red = static_cast<const ngraph::op::util::IndexReduction*>(ng_node);
+    auto arg = idx_red->get_argument(0);
+    size_t red_axis = idx_red->get_reduction_axis();
+
+    mlir::Value* arg_val = get_tensor_value(arg->get_output_tensor_ptr().get()).m_value;
+    mlir::ArrayAttr red_axes_attr = m_builder->getI64ArrayAttr({(int64_t)red_axis});
+
+    return m_builder
+        ->create<RedOp>(
+            mlir::UnknownLoc::get(&m_context), get_mlir_type(ng_node), arg_val, red_axes_attr)
+        .getResult();
+}
+
 // Binds MLIR function arguments to the proper values. This includes externally allocated tensors
 // helpers to be used inside the function.
 void MLIRCompiler::bind_arguments()

@@ -376,10 +418,17 @@ void MLIRCompiler::execute()
     llvm::InitializeNativeTarget();
     llvm::InitializeNativeTargetAsmPrinter();

+    unsigned opt_level = 3;
+    if (char* opt_level_str = std::getenv("NGRAPH_MLIR_OPT_LEVEL"))
+    {
+        opt_level = std::stoi(opt_level_str);
+        NGRAPH_CHECK(opt_level >= 0 && opt_level <= 3, "Invalid optimization level");
+    }
     // Create an MLIR execution engine. We use a null MLIR pass manager for now to make sure we
     // don't run MLIR passes that were already run. We also pass a default transformer to run
     // LLVM optimizations at level 3.
-    auto llvm_transformer = mlir::makeOptimizingTransformer(3 /*optLevel*/, 0 /*sizeLevel*/);
+    auto llvm_transformer = mlir::makeOptimizingTransformer(opt_level /*optLevel*/, 0 /*sizeLevel*/);
     auto maybeEngine = mlir::ExecutionEngine::create(m_module.get(), llvm_transformer);
     NGRAPH_CHECK(maybeEngine, "failed to construct an execution engine");
     m_engine = std::move(maybeEngine.get());
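The last hunk above introduces an NGRAPH_MLIR_OPT_LEVEL environment variable that overrides the previously hard-coded LLVM optimization level of 3. A minimal usage sketch, assuming a POSIX environment (setenv) and that the variable is set before MLIRCompiler::execute() creates the execution engine:

```cpp
// Hypothetical test-harness snippet: lower the LLVM optimization level used by
// the MLIR execution engine from the default 3 to 1 for the current process.
#include <cstdlib>

int main()
{
    // Values outside 0..3 trip the NGRAPH_CHECK added in this commit.
    setenv("NGRAPH_MLIR_OPT_LEVEL", "1", /*overwrite=*/1);

    // ... build and execute an nGraph function through the MLIR path as usual ...
    return 0;
}
```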
src/contrib/mlir/compiler.hpp

@@ -91,6 +91,8 @@ namespace ngraph
             mlir::Type get_mlir_type(const descriptor::Tensor* tensor);
             mlir::Type get_mlir_type(const element::Type& type);
+            mlir::Type get_mlir_type(const ngraph::Node* node);
+
             TensorInfo get_tensor_value(descriptor::Tensor* tensor);
             void update_tensor_value(descriptor::Tensor* tensor, mlir::Value* value);

@@ -106,6 +108,9 @@ namespace ngraph
             template <typename BinOp>
             mlir::Value* create_binary_op(const ngraph::Node* ng_node);

+            template <typename RedOp>
+            mlir::Value* create_index_reduction(const ngraph::Node* ng_node);
+
             void create_return();

             /// Helper to create memref arguments for MLIR function signature
src/contrib/mlir/dialect/ops.cpp

@@ -97,7 +97,7 @@ template <typename T>
 static mlir::LogicalResult verifyIndexReductionOp(T* op)
 {
     // TODO: verifyAxisReductionOp(op) + return element type + single axis.
-    return mlir::failure();
+    return mlir::success();
 }

 template <typename T>
src/contrib/mlir/dialect/type.hpp

@@ -160,6 +160,7 @@ namespace mlir
         static bool kindof(unsigned kind) { return kind == NGTypeKind::NG_BOOL_TYPE_ID; }

         static NGBoolType get(mlir::MLIRContext* ctx) { return get(NG_BOOL_TYPE_ID, ctx); }
+        size_t getWidth() { return 8; }
     };

     // Note that dialect types don't add new data members, so always possible

@@ -222,6 +223,23 @@ namespace mlir
         int getRank() { return getShape().size(); }

         /// Computes tensor size in bytes
-        size_t getSizeInBytes()
+        size_t getSizeInBytes() { return getNumElements() * llvm::divideCeil(getElementBitWidth(), 8); }
+
+        size_t getElementBitWidth()
+        {
+            Type type = getElementType();
+            if (NGIntegerType intType = type.dyn_cast<NGIntegerType>())
+                return intType.getWidth();
+            if (NGFloatType floatType = type.dyn_cast<NGFloatType>())
+                return floatType.getIntOrFloatBitWidth();
+            if (NGBoolType boolType = type.dyn_cast<NGBoolType>())
+                return boolType.getWidth();
+            NGRAPH_CHECK(false, "Unknown type");
+            return -1;
+        }
+
+        /// Get number of elements
+        size_t getNumElements()
         {
             size_t s = 1;
             auto shape = getShape();

@@ -232,10 +250,8 @@ namespace mlir
                     return -1;
                 s *= shape[i];
             }
-            // Multiply times element size
-            return s * llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+            return s;
         }

         /// Checks if two tensors are compatible. Compatible means:
         /// Exactly same element types
         /// Compatible shapes: see isCompatibleShape.
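This refactor splits the old getSizeInBytes into getNumElements and getElementBitWidth so that dialect element types such as NGBoolType (which now reports an 8-bit width) no longer need to answer getIntOrFloatBitWidth. A standalone arithmetic sketch of the same formula, with hypothetical inputs:

```cpp
// size_in_bytes = num_elements * ceil(element_bit_width / 8), mirroring
// NGTensorType::getSizeInBytes() after this change.
#include <cassert>
#include <cstddef>

constexpr size_t size_in_bytes(size_t num_elements, size_t element_bit_width)
{
    return num_elements * ((element_bit_width + 7) / 8);
}

int main()
{
    assert(size_in_bytes(2 * 3, 32) == 24); // 2x3 i32/f32 tensor
    assert(size_in_bytes(5, 8) == 5);       // 5-element boolean tensor (8-bit bools)
    assert(size_in_bytes(4, 1) == 4);       // sub-byte widths round up per element
}
```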
src/contrib/mlir/lowerer.cpp

@@ -37,7 +37,9 @@ namespace
 {
     using namespace mlir;
     using namespace mlir::edsc;
+    using namespace mlir::edsc::op;
     using namespace ngraph::runtime;
+    using namespace ngraph::runtime::ngmlir;

     class DialectLoweringPass;

@@ -59,6 +61,13 @@ namespace
 #include "op_lowerers.inc"

+    // Helpers
+    template <typename RedOp>
+    void lowerIndexReduction(Operation* op,
+                             ArrayRef<Value*> operands,
+                             PatternRewriter& rewriter,
+                             DialectLoweringPass& m_pass);
+
     /// Conversion from types in the nGraph dialect to the Standard dialect.
     class NGraphTypeConverter : public TypeConverter
     {

@@ -82,15 +91,17 @@ namespace
         void runOnModule() override;

         SmallVector<Value*, 4> buildOutputDefs(Operation* op, PatternRewriter& rewriter);
+        Value* createTempTensor(Type type, unsigned size, PatternRewriter& rewriter);
+
+        mlir::Function* getCallDecl(StringRef name,
+                                    ArrayRef<Type> args,
+                                    ArrayRef<Type> output,
+                                    PatternRewriter& rewriter);

     private:
         /// Collect a set of patterns to convert from the nGraph dialect to Affine dialect.
         void populateNGraphToAffineConversionPatterns(OwningRewritePatternList& patterns);
-        mlir::Function* getCallDecl(StringRef name,
-                                    ArrayRef<Type> args,
-                                    ArrayRef<Type> output,
-                                    PatternRewriter& rewriter);
         void findOutputValues();
         void processFakeInstrs();
         Value* insertMemMgrDef(PatternRewriter* rewriter = nullptr);

@@ -136,8 +147,11 @@ namespace
     void DialectLoweringPass::populateNGraphToAffineConversionPatterns(
         OwningRewritePatternList& patterns)
     {
-        RewriteListBuilder<NGAddOpConversion, NGDotOpConversion, NGReturnOpConversion>::build(
-            patterns, &getContext(), *this);
+        RewriteListBuilder<NGAddOpConversion,
+                           NGArgMaxRedOpConversion,
+                           NGArgMinRedOpConversion,
+                           NGDotOpConversion,
+                           NGReturnOpConversion>::build(patterns, &getContext(), *this);
     }

     void DialectLoweringPass::findOutputValues()

@@ -206,25 +220,30 @@ namespace
             else
             {
                 auto tensorType = origResult->getType().cast<NGTensorType>();
-                auto callBackFunc = getCallDecl("__mlir_allocate",
-                                                {rewriter.getIndexType(), rewriter.getIndexType()},
-                                                {m_typeConverter.convertType(tensorType)},
-                                                rewriter);
-                auto size = tensorType.getSizeInBytes();
-                SmallVector<mlir::Value*, 4> args = {
-                    insertMemMgrDef(&rewriter), /* pointer to mem manager */
-                    rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
-                                                           size)}; /* size to allocate */
-                auto newResult =
-                    rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args)
-                        .getResult(0);
+                auto newResult = createTempTensor(m_typeConverter.convertType(tensorType),
+                                                  tensorType.getSizeInBytes(),
+                                                  rewriter);
                 newResults.push_back(newResult);
             }
         }
         return newResults;
     }

+    Value* DialectLoweringPass::createTempTensor(Type type, unsigned size, PatternRewriter& rewriter)
+    {
+        auto callBackFunc = getCallDecl("__mlir_allocate",
+                                        {rewriter.getIndexType(), rewriter.getIndexType()},
+                                        {type},
+                                        rewriter);
+        SmallVector<mlir::Value*, 4> args = {
+            insertMemMgrDef(&rewriter), /* pointer to mem manager */
+            rewriter.create<mlir::ConstantIndexOp>(rewriter.getUnknownLoc(),
+                                                   size)}; /* size to allocate */
+        auto newTemp =
+            rewriter.create<mlir::CallOp>(rewriter.getUnknownLoc(), callBackFunc, args).getResult(0);
+        return newTemp;
+    }
+
     void DialectLoweringPass::processFakeInstrs()
     {
         auto context = getModule().getContext();

@@ -326,7 +345,6 @@ namespace
     // ADD
     REWRITER(NGAddOp)
     {
-
         auto add = cast<NGAddOp>(op);
         auto loc = add.getLoc();

@@ -365,6 +383,18 @@ namespace
         return matchSuccess();
     }

+    REWRITER(NGArgMaxRedOp)
+    {
+        lowerIndexReduction<mlir::NGArgMaxRedOp>(op, operands, rewriter, m_pass);
+        return matchSuccess();
+    }
+
+    REWRITER(NGArgMinRedOp)
+    {
+        lowerIndexReduction<mlir::NGArgMinRedOp>(op, operands, rewriter, m_pass);
+        return matchSuccess();
+    }
+
     REWRITER(NGDotOp)
     {
         auto dot = cast<NGDotOp>(op);

@@ -412,7 +442,7 @@ namespace
         IndexHandle n_ub(v_lhs.ub(n_dim)), m_ub(v_lhs.ub(m_dim)), k_ub(v_rhs.ub(k_dim));
         int64_t n_step = v_lhs.step(n_dim), m_step = v_lhs.step(m_dim), k_step = v_rhs.step(k_dim);

-        // Constants, indexed values and indexes to be used inside the loop nest.
+        // Constants and indexed values to be used inside the loop nest.
         IndexedValue i_res(result), i_lhs(lhs), i_rhs(rhs);
         ValueHandle zero_init(rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(elem_ty)));

@@ -436,6 +466,96 @@ namespace
     }

 #undef REWRITER

+    template <typename RedOp>
+    void lowerIndexReduction(Operation* op,
+                             ArrayRef<Value*> operands,
+                             PatternRewriter& rewriter,
+                             DialectLoweringPass& m_pass)
+    {
+        static_assert(std::is_same<RedOp, NGArgMinRedOp>() || std::is_same<RedOp, NGArgMaxRedOp>(),
+                      "Template parameter is not supported by lowerIndexReduction");
+
+        RedOp redOp = cast<RedOp>(op);
+        auto loc = redOp.getLoc();
+        auto axesAttr = redOp.axes();
+        NGRAPH_CHECK(axesAttr.size() == 1, "Index Reduction op should have one reduction axis");
+        Attribute axisAttr = *axesAttr.begin();
+        unsigned axis = axisAttr.dyn_cast<IntegerAttr>().getInt();
+
+        NGRAPH_CHECK(operands.size() == 1 && operands[0] != nullptr,
+                     "Expected one non-null operand in Index Reduction op");
+
+        // Retrieve/generate Values for operands and result.
+        ScopedContext scope(rewriter, loc);
+        Value* arg = operands[0];
+        Value* result = m_pass.buildOutputDefs(op, rewriter)[0];
+
+        // Views
+        MemRefView vRes(result), vArg(arg);
+        // Index Values
+        IndexedValue iRes(result), iArg(arg);
+        // Bounds Index Handles
+        auto resLbs = vRes.getLbs();
+        auto resUbs = vRes.getUbs();
+        auto argLbs = vArg.getLbs();
+        auto argUbs = vArg.getUbs();
+
+        Type resTy = result->getType().cast<MemRefType>().getElementType();
+        // Generate loop nest that initializes result to lower bound of the axis to be reduced.
+        {
+            auto ivs = IndexHandle::makeIndexHandles(vRes.rank());
+            auto pivs = IndexHandle::makeIndexHandlePointers(ivs);
+            auto steps = vRes.getSteps();
+            auto initVal = vArg.lb(axis);
+            LoopNestBuilder(pivs, resLbs, resUbs, steps)(
+                [&] { iRes(ivs) = ValueHandle::create<IndexCastOp>(initVal, resTy); });
+        }
+
+        // Generate loop nest that computes the actual index reduction.
+        {
+            auto allIVs = IndexHandle::makeIndexHandles(vArg.rank());
+            auto pAllIVs = IndexHandle::makeIndexHandlePointers(allIVs);
+            auto steps = vArg.getSteps();
+            SmallVector<IndexHandle, 8> nonRedIVs;
+
+            Type resTy = result->getType().cast<MemRefType>().getElementType();
+            NGRAPH_CHECK(resTy.isa<IntegerType>(),
+                         "Expected integer result type in index reduction");
+
+            // iterate over all argument dimensions
+            LoopNestBuilder(pAllIVs, argLbs, argUbs, steps)([&] {
+                // build a list of non-reduction IVs
+                for (auto i = 0; i < vArg.rank(); i++)
+                {
+                    if (i != axis)
+                        nonRedIVs.push_back(allIVs[i]);
+                }
+
+                // Load current min index with integer data type and convert it to index data type.
+                ValueHandle currRedIdx = ValueHandle::create<IndexCastOp>(
+                    (ValueHandle)iRes(nonRedIVs), IndexType::get(resTy.getContext()));
+
+                // Build list of IVs including current min index.
+                auto tempIVs = allIVs;
+                tempIVs[axis] = currRedIdx;
+
+                // Select the min/max value and cast it back to integer type before storing it.
+                ValueHandle newRedIdx =
+                    std::is_same<RedOp, NGArgMinRedOp>()
+                        ? edsc::intrinsics::select(
+                              iArg(allIVs) < iArg(tempIVs), allIVs[axis], currRedIdx)
+                        : edsc::intrinsics::select(
+                              iArg(tempIVs) < iArg(allIVs), allIVs[axis], currRedIdx);
+
+                iRes(nonRedIVs) = ValueHandle::create<IndexCastOp>(newRedIdx, resTy);
+            });
+        }
+
+        rewriter.replaceOp(op, result);
+    }
 }

 namespace mlir
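The core of lowerIndexReduction is the swapped "<" predicate fed to edsc::intrinsics::select: ArgMin keeps index i when arg[i] < arg[best], while ArgMax reuses the same comparison with the operands swapped, so both keep the earliest index on ties. A standalone 1-D sketch of that predicate (illustrative only; the generated code is affine loops over memrefs):

```cpp
// Minimal sketch of the select predicate used above: one "<" comparison serves
// both reductions, with the operand order deciding min vs. max.
#include <cassert>
#include <cstddef>
#include <vector>

template <bool IsMin>
size_t index_reduce_1d(const std::vector<int>& arg)
{
    size_t best = 0; // lower bound of the reduced axis
    for (size_t i = 0; i < arg.size(); ++i)
        best = (IsMin ? arg[i] < arg[best] : arg[best] < arg[i]) ? i : best;
    return best;
}

int main()
{
    std::vector<int> v = {3, 1, 4, 1, 5};   // two minima (indices 1 and 3)
    assert(index_reduce_1d<true>(v) == 1);  // ArgMin: earliest minimum wins
    assert(index_reduce_1d<false>(v) == 4); // ArgMax: 5 at index 4
    std::vector<int> w = {7, 9, 9, 2};      // two maxima (indices 1 and 2)
    assert(index_reduce_1d<false>(w) == 1); // earliest maximum wins
}
```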
src/contrib/mlir/op_lowerers.inc

@@ -32,6 +32,8 @@
 };

 DECL_OP_CONV(NGAddOp)
+DECL_OP_CONV(NGArgMaxRedOp)
+DECL_OP_CONV(NGArgMinRedOp)
 DECL_OP_CONV(NGDotOp)
 DECL_OP_CONV(NGReturnOp)
src/contrib/mlir/ops_supported.inc

@@ -4,6 +4,8 @@
 #endif

 MLIR_OP(Add)
+MLIR_OP(ArgMin)
+MLIR_OP(ArgMax)
 MLIR_OP(Dot)
 // Add new supported ops here
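ops_supported.inc and op_lowerers.inc are X-macro lists: each including file defines MLIR_OP or DECL_OP_CONV before the #include, so adding ArgMin/ArgMax to the list registers them everywhere the list is expanded. A generic standalone sketch of the idiom, using hypothetical macro names rather than the actual nGraph ones:

```cpp
#include <iostream>

// The op list lives in one place (normally a separate .inc file).
#define SUPPORTED_OPS \
    OP(Add)           \
    OP(ArgMin)        \
    OP(ArgMax)        \
    OP(Dot)

// Expansion 1: an enum of supported ops.
enum class OpKind
{
#define OP(name) name,
    SUPPORTED_OPS
#undef OP
};

// Expansion 2: a printable name table generated from the same list.
const char* op_name(OpKind kind)
{
    switch (kind)
    {
#define OP(name) case OpKind::name: return #name;
        SUPPORTED_OPS
#undef OP
    }
    return "unknown";
}

int main()
{
    std::cout << op_name(OpKind::ArgMin) << '\n'; // prints: ArgMin
}
```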
src/contrib/mlir/pass/mlir_subgraph_extraction.cpp

@@ -15,9 +15,12 @@
 //*****************************************************************************

 #include "mlir_subgraph_extraction.hpp"
 #include "ngraph/assertion.hpp"
 #include "ngraph/graph_util.hpp"
 #include "ngraph/op/add.hpp"
+#include "ngraph/op/argmax.hpp"
+#include "ngraph/op/argmin.hpp"
 #include "ngraph/op/dot.hpp"
 #include "ngraph/op/experimental/compiled_kernel.hpp"
 #include "ngraph/op/get_output_element.hpp"

@@ -105,6 +108,13 @@ bool MLIRSubgraphExtractionPass::is_supported_mlir_op(std::shared_ptr<Node> node
             return false;
         }
     }

+    if (TI(ngraph::op::ArgMin) == TI(*node) || TI(ngraph::op::ArgMax) == TI(*node))
+    {
+        // TODO: Remove this when MLIR has float point cmp support
+        if (!node->input(0).get_element_type().is_integral())
+            return false;
+    }
     return true;
 }
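Until MLIR gains floating-point comparison support, the pass only offloads ArgMin/ArgMax nodes whose input element type is integral. For illustration, under that rule the first function below could land in an MLIR-compiled subgraph while the second stays on the regular backend path. This sketch reuses the same nGraph calls as the new tests; the umbrella ngraph/ngraph.hpp include is assumed:

```cpp
#include "ngraph/ngraph.hpp"

using namespace ngraph;

int main()
{
    // Integral input: eligible for the MLIR subgraph.
    auto A = std::make_shared<op::Parameter>(element::i32, Shape{4, 3});
    auto argmin_i32 = std::make_shared<Function>(
        std::make_shared<op::ArgMin>(A, 0, element::i32), ParameterVector{A});

    // Floating-point input: is_supported_mlir_op() now returns false for this node,
    // so it falls back to the default backend implementation.
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{4, 3});
    auto argmin_f32 = std::make_shared<Function>(
        std::make_shared<op::ArgMin>(B, 0, element::i64), ParameterVector{B});
    return 0;
}
```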
src/ngraph/runtime/plaidml/unit_test.manifest

@@ -240,6 +240,10 @@ batch_norm_training_0eps_f32
 argmin_trivial
 argmax_trivial
 argmin_trivial_in_i32
+argmin_3D_i32
+argmin_3D_i64
+argmax_3D_i32
+argmax_3D_i64
 sum_large_1d_to_scalar
 sum_stable_acc
 one_hot_scalar_2_in_3
test/backend_arg_reduce.in.cpp

@@ -55,6 +55,107 @@ NGRAPH_TEST(${BACKEND_NAME}, argmin_trivial)
     EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
 }

+NGRAPH_TEST(${BACKEND_NAME}, argmin_2D_i32)
+{
+    Shape shape{4, 3};
+    Shape rshape{3};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMin>(A, 0, element::i32), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
+    auto result = backend->create_tensor(element::i32, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int>{3, 2, 1}), read_vector<int>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i32)
+{
+    Shape shape{3, 3, 4};
+    Shape rshape{3, 4};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMin>(A, 1, element::i32), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
+    auto result = backend->create_tensor(element::i32, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmin_3D_i64)
+{
+    Shape shape{3, 3, 4};
+    Shape rshape{3, 4};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMin>(A, 1, element::i64), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
+    auto result = backend->create_tensor(element::i64, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int64_t>{1, 0, 1, 2, 1, 2, 2, 2, 1, 0, 0, 1}), read_vector<int64_t>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_i64)
+{
+    Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
+    Shape rshape{2, 2, 5};
+    auto A = make_shared<op::Parameter>(element::f32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMin>(A, 3, element::i64), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::f32, shape);
+    copy_data(
+        a,
+        test::NDArray<int, 4>(
+            {{{{3, 1, 1, 2, 105}, {0, 3, 2, 1, 2}, {2, 4, 2, 0, 1}, {2, 5, 1, 1, 22}, {5, 2, 1, 7, 5}},
+              {{3, 1, 2, 2, 1}, {1, 7, 3, 8, 1}, {2, 10, 1, 3, 2}, {3, 1, 0, 0, 6}, {2, 0, 0, 0, 0}}},
+             {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
+              {{2, 1, 0, 0, 1}, {0, 2, 0, 0, 0}, {1, 1, 2, 0, 2}, {1, 1, 1, 0, 1}, {1, 0, 0, 0, 2}}}})
+            .get_vector());
+    auto result = backend->create_tensor(element::i64, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int64_t>{1, 0, 3, 2, 2, 1, 0, 2, 2, 1, 0, 0, 0, 1, 0, 2, 0, 3, 3, 1}),
+              read_vector<int64_t>(result));
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, argmin_4D_axis_3_i64)
 {
     Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)

@@ -158,6 +259,107 @@ NGRAPH_TEST(${BACKEND_NAME}, argmax_trivial)
     EXPECT_EQ((vector<int>{1, 3, 0}), read_vector<int>(result));
 }

+NGRAPH_TEST(${BACKEND_NAME}, argmax_2D_i32)
+{
+    Shape shape{4, 3};
+    Shape rshape{3};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMax>(A, 0, element::i32), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a, vector<int>{12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7});
+    auto result = backend->create_tensor(element::i32, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int>{0, 3, 0}), read_vector<int>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i32)
+{
+    Shape shape{3, 3, 4};
+    Shape rshape{3, 4};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i32), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
+    auto result = backend->create_tensor(element::i32, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_i64)
+{
+    Shape shape{3, 3, 4};
+    Shape rshape{3, 4};
+    auto A = make_shared<op::Parameter>(element::i32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMax>(A, 1, element::i64), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::i32, shape);
+    copy_data(a,
+              test::NDArray<int, 3>({{{12, 2, 10, 9}, {3, 5, 0, 8}, {7, 9, 1, 5}},
+                                     {{7, 2, 4, 10}, {6, 10, 2, 2}, {12, 1, 1, 1}},
+                                     {{10, 2, 2, 4}, {1, 5, 5, 1}, {7, 12, 2, 2}}})
+                  .get_vector());
+    auto result = backend->create_tensor(element::i64, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int64_t>{0, 2, 0, 0, 2, 1, 0, 0, 0, 2, 1, 0}), read_vector<int64_t>(result));
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, argmax_4D_i64)
+{
+    Shape shape{2, 2, 5, 5}; // NCHW ->(0,1,2,3)
+    Shape rshape{2, 2, 5};
+    auto A = make_shared<op::Parameter>(element::f32, shape);
+    auto f =
+        make_shared<Function>(make_shared<op::ArgMax>(A, 3, element::i64), ParameterVector{A});
+
+    auto backend = runtime::Backend::create("${BACKEND_NAME}");
+
+    // Create some tensors for input/output
+    auto a = backend->create_tensor(element::f32, shape);
+    copy_data(
+        a,
+        test::NDArray<int, 4>(
+            {{{{3, 1, 1, 2, 105}, {0, 3, 2, 1, 2}, {2, 4, 2, 0, 1}, {2, 5, 1, 1, 22}, {5, 2, 1, 7, 5}},
+              {{3, 1, 2, 2, 1}, {1, 7, 3, 8, 1}, {2, 10, 1, 3, 2}, {3, 1, 0, 0, 6}, {2, 0, 0, 0, 0}}},
+             {{{0, 2, 1, 1, 0}, {0, 0, 0, 0, 1}, {0, 0, 1, 0, 3}, {2, 0, 0, 3, 0}, {0, 0, 0, 0, 1}},
+              {{2, 1, 0, 0, 1}, {0, 2, 0, 0, 0}, {1, 1, 2, 0, 2}, {1, 1, 1, 0, 1}, {1, 0, 0, 0, 2}}}})
+            .get_vector());
+    auto result = backend->create_tensor(element::i64, rshape);
+
+    auto handle = backend->compile(f);
+    handle->call_with_validate({result}, {a});
+    EXPECT_EQ((vector<int64_t>{4, 1, 1, 4, 3, 0, 3, 1, 4, 0, 1, 4, 4, 3, 4, 0, 1, 2, 0, 4}),
+              read_vector<int64_t>(result));
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, argmax_3D_axis_0) // Along Channels
 {
     Shape shape{3, 4, 2}; // CHW ->(0,1,2)
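As a sanity check on the new 2D cases: the 4x3 input is laid out row-major as {{12,2,10},{9,8,4},{6,1,5},{3,11,7}}, so reducing over axis 0 takes each column's extremum across rows. The column minima sit in rows {3,2,1} (values 3,1,4) and the column maxima in rows {0,3,0} (values 12,11,10), matching the vectors the tests expect. A tiny self-contained check (standalone sketch, not using the test harness):

```cpp
#include <cassert>
#include <vector>

int main()
{
    // Same 4x3 row-major input as argmin_2D_i32 / argmax_2D_i32.
    std::vector<int> a = {12, 2, 10, 9, 8, 4, 6, 1, 5, 3, 11, 7};
    const int rows = 4, cols = 3;

    std::vector<int> amin(cols, 0), amax(cols, 0);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
        {
            if (a[r * cols + c] < a[amin[c] * cols + c]) amin[c] = r;
            if (a[amax[c] * cols + c] < a[r * cols + c]) amax[c] = r;
        }

    assert((amin == std::vector<int>{3, 2, 1})); // expected by argmin_2D_i32
    assert((amax == std::vector<int>{0, 3, 0})); // expected by argmax_2D_i32
}
```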