Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
1eb9f9bf
Unverified
Commit
1eb9f9bf
authored
Apr 24, 2018
by
Robert Kimball
Committed by
GitHub
Apr 24, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update to enable pass backend unit tests (#904)
* get all ops working * enable autodiff tests for IE backend
parent
a8a68452
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
763 additions
and
635 deletions
+763
-635
ie_backend.cpp
src/ngraph/runtime/ie/ie_backend.cpp
+17
-2
ie_backend.hpp
src/ngraph/runtime/ie/ie_backend.hpp
+745
-622
select.hpp
src/ngraph/runtime/reference/select.hpp
+1
-0
autodiff.in.cpp
test/autodiff.in.cpp
+0
-4
backend_test.in.cpp
test/backend_test.in.cpp
+0
-7
No files found.
src/ngraph/runtime/ie/ie_backend.cpp
View file @
1eb9f9bf
...
...
@@ -17,6 +17,9 @@
#include "ngraph/runtime/ie/ie_backend.hpp"
#include "ngraph/descriptor/layout/dense_tensor_view_layout.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
...
...
@@ -145,11 +148,23 @@ bool runtime::ie::IE_Backend::call(shared_ptr<Function> function,
}
// get op type
element
::
Type
type
=
op
->
get_element_type
();
if
(
!
op
->
get_inputs
().
empty
())
element
::
Type
type
;
if
(
dynamic_pointer_cast
<
op
::
util
::
BinaryElementwiseComparison
>
(
op
)
||
dynamic_pointer_cast
<
op
::
Select
>
(
op
))
{
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type
=
op
->
get_inputs
().
at
(
1
).
get_tensor
().
get_element_type
();
}
else
if
(
dynamic_pointer_cast
<
op
::
Convert
>
(
op
))
{
type
=
op
->
get_inputs
().
at
(
0
).
get_tensor
().
get_element_type
();
}
else
{
type
=
op
->
get_element_type
();
}
generate_calls
(
type
,
*
op
,
op_outputs
,
op_inputs
);
...
...
src/ngraph/runtime/ie/ie_backend.hpp
View file @
1eb9f9bf
...
...
@@ -47,6 +47,7 @@
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/select_and_scatter.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
...
...
@@ -58,6 +59,7 @@
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/runtime/reference/cos.hpp"
...
...
@@ -93,6 +95,8 @@
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select_and_scatter.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
...
...
@@ -114,649 +118,768 @@ namespace ngraph
{
namespace
ie
{
class
IE_Backend
:
public
Backend
{
public
:
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
override
;
class
IE_Backend
;
}
}
}
class
ngraph
::
runtime
::
ie
::
IE_Backend
:
public
Backend
{
public
:
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
override
;
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
std
::
shared_ptr
<
TensorView
>
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
override
;
bool
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
bool
compile
(
std
::
shared_ptr
<
Function
>
function
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
intputs
)
override
;
bool
call
(
std
::
shared_ptr
<
Function
>
function
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
TensorView
>>&
intputs
)
override
;
private
:
static
bool
init
;
void
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
inputs
);
private
:
static
bool
init
;
void
generate_calls
(
const
element
::
Type
&
type
,
Node
&
op
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
inputs
);
template
<
typename
T
>
void
op_engine
(
Node
&
node
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
)
{
std
::
string
node_op
=
node
.
description
();
if
(
node_op
==
"Abs"
)
{
reference
::
abs
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Acos"
)
{
reference
::
acos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Add"
)
{
reference
::
add
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
template
<
typename
T
>
void
op_engine
(
Node
&
node
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensorView
>>&
args
)
{
std
::
string
node_op
=
node
.
description
();
if
(
node_op
==
"Abs"
)
{
reference
::
abs
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Acos"
)
{
reference
::
acos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Add"
)
{
reference
::
add
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
#ifdef NGRAPH_DISTRIBUTED
else
if
(
node_op
==
"AllReduce"
)
{
reference
::
allreduce
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_element_type
(),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
()));
}
else
if
(
node_op
==
"AllReduce"
)
{
reference
::
allreduce
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_element_type
(),
static_cast
<
int
>
(
args
[
0
]
->
get_element_count
()));
}
#endif
else
if
(
node_op
==
"And"
)
{
reference
::
logical_and
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Asin"
)
{
reference
::
asin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Atan"
)
{
reference
::
atan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"AvgPool"
)
{
op
::
AvgPool
*
avg_pool
=
dynamic_cast
<
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
avg_pool
->
get_padding_above
(),
avg_pool
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"AvgPoolBackprop"
)
{
op
::
AvgPoolBackprop
*
apb
=
dynamic_cast
<
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
apb
->
get_padding_above
(),
apb
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"Broadcast"
)
{
op
::
Broadcast
*
broadcast
=
dynamic_cast
<
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
();
Shape
out_shape
=
out
[
0
]
->
get_shape
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shape
,
out_shape
,
broadcast_axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
reference
::
ceiling
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Concat"
)
{
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
std
::
shared_ptr
<
HostTensorView
>
arg
:
args
)
{
in_args
.
push_back
(
reinterpret_cast
<
T
*>
(
arg
->
get_data_ptr
()));
in_shapes
.
push_back
(
arg
->
get_shape
());
}
reference
::
concat
<
T
>
(
in_args
,
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shapes
,
out
[
0
]
->
get_shape
(),
concat
->
get_concatenation_axis
());
}
else
if
(
node_op
==
"Constant"
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
reinterpret_cast
<
const
T
*>
(
c
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convolution"
)
{
auto
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropFilters"
)
{
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropData"
)
{
// Note that args[1] and args[0] are switched here from the usual order.
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
}
else
if
(
node_op
==
"Cos"
)
{
reference
::
cos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Cosh"
)
{
reference
::
cosh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Divide"
)
{
reference
::
divide
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Dot"
)
{
op
::
Dot
*
dot
=
dynamic_cast
<
op
::
Dot
*>
(
&
node
);
reference
::
dot
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
dot
->
get_reduction_axes_count
());
}
else
if
(
node_op
==
"And"
)
{
reference
::
logical_and
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Asin"
)
{
reference
::
asin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Atan"
)
{
reference
::
atan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"AvgPool"
)
{
op
::
AvgPool
*
avg_pool
=
dynamic_cast
<
op
::
AvgPool
*>
(
&
node
);
else
if
(
node_op
==
"Equal"
)
{
reference
::
equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Exp"
)
{
reference
::
exp
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Floor"
)
{
reference
::
floor
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reference
::
avg_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_movement_strides
(),
avg_pool
->
get_padding_below
(),
avg_pool
->
get_padding_above
(),
avg_pool
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"AvgPoolBackprop"
)
{
op
::
AvgPoolBackprop
*
apb
=
dynamic_cast
<
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"FunctionCall"
)
{
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
apb
->
get_window_shape
(),
apb
->
get_window_movement_strides
(),
apb
->
get_padding_below
(),
apb
->
get_padding_above
(),
apb
->
get_include_padding_in_avg_computation
());
}
else
if
(
node_op
==
"Broadcast"
)
{
op
::
Broadcast
*
broadcast
=
dynamic_cast
<
op
::
Broadcast
*>
(
&
node
);
Shape
in_shape
=
args
[
0
]
->
get_shape
();
Shape
out_shape
=
out
[
0
]
->
get_shape
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
reference
::
broadcast
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shape
,
out_shape
,
broadcast_axes
);
}
else
if
(
node_op
==
"Ceiling"
)
{
reference
::
ceiling
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Concat"
)
{
const
op
::
Concat
*
concat
=
static_cast
<
const
op
::
Concat
*>
(
&
node
);
std
::
vector
<
const
T
*>
in_args
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
std
::
shared_ptr
<
HostTensorView
>
arg
:
args
)
{
in_args
.
push_back
(
reinterpret_cast
<
T
*>
(
arg
->
get_data_ptr
()));
in_shapes
.
push_back
(
arg
->
get_shape
());
}
reference
::
concat
<
T
>
(
in_args
,
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
in_shapes
,
out
[
0
]
->
get_shape
(),
concat
->
get_concatenation_axis
());
}
else
if
(
node_op
==
"Constant"
)
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
reference
::
constant
<
T
>
(
reinterpret_cast
<
const
T
*>
(
c
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Convert"
)
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element
::
Type
type
=
node
.
get_element_type
();
if
(
type
==
element
::
boolean
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
float
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
double
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
i64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
int64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u8
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint8_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u16
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint16_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u32
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint32_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
type
==
element
::
u64
)
{
reference
::
convert
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
uint64_t
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
}
else
if
(
node_op
==
"Convolution"
)
{
auto
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropFilters"
)
{
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
}
else
if
(
node_op
==
"ConvolutionBackpropData"
)
{
// Note that args[1] and args[0] are switched here from the usual order.
auto
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
}
else
if
(
node_op
==
"Cos"
)
{
reference
::
cos
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Cosh"
)
{
reference
::
cosh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Divide"
)
{
reference
::
divide
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Dot"
)
{
op
::
Dot
*
dot
=
dynamic_cast
<
op
::
Dot
*>
(
&
node
);
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
outputs
;
for
(
auto
tv
:
out
)
{
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
reference
::
dot
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
dot
->
get_reduction_axes_count
());
}
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
inputs
;
for
(
auto
tv
:
args
)
{
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
else
if
(
node_op
==
"Equal"
)
{
reference
::
equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Exp"
)
{
reference
::
exp
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Floor"
)
{
reference
::
floor
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"FunctionCall"
)
{
std
::
shared_ptr
<
Function
>
function
=
node
.
get_functions
()[
0
];
call
(
function
,
outputs
,
inputs
);
}
else
if
(
node_op
==
"Greater"
)
{
reference
::
greater
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"GreaterEq"
)
{
reference
::
greater_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Less"
)
{
reference
::
less
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"LessEq"
)
{
reference
::
less_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Log"
)
{
reference
::
log
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Max"
)
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Maximum"
)
{
reference
::
maximum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"MaxPool"
)
{
op
::
MaxPool
*
max_pool
=
dynamic_cast
<
op
::
MaxPool
*>
(
&
node
);
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
outputs
;
for
(
auto
tv
:
out
)
{
outputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
reference
::
max_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
max_pool
->
get_padding_above
());
}
else
if
(
node_op
==
"MaxPoolBackprop"
)
{
op
::
MaxPoolBackprop
*
max_pool_backprop
=
dynamic_cast
<
op
::
MaxPoolBackprop
*>
(
&
node
);
std
::
vector
<
std
::
shared_ptr
<
runtime
::
TensorView
>>
inputs
;
for
(
auto
tv
:
args
)
{
inputs
.
push_back
(
std
::
static_pointer_cast
<
runtime
::
TensorView
>
(
tv
));
}
reference
::
max_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
max_pool_backprop
->
get_padding_above
());
}
else
if
(
node_op
==
"Min"
)
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
min
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Minimum"
)
{
reference
::
minimum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Multiply"
)
{
reference
::
multiply
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Negative"
)
{
reference
::
negate
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Not"
)
{
reference
::
logical_not
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"NotEqual"
)
{
reference
::
not_equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"OneHot"
)
{
auto
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
oh
->
get_one_hot_axis
());
}
else
if
(
node_op
==
"Or"
)
{
reference
::
logical_or
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Parameter"
)
{
}
else
if
(
node_op
==
"Pad"
)
{
op
::
Pad
*
pad
=
dynamic_cast
<
op
::
Pad
*>
(
&
node
);
call
(
function
,
outputs
,
inputs
);
}
else
if
(
node_op
==
"Greater"
)
{
reference
::
greater
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"GreaterEq"
)
{
reference
::
greater_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Less"
)
{
reference
::
less
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"LessEq"
)
{
reference
::
less_eq
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Log"
)
{
reference
::
log
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Max"
)
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Maximum"
)
{
reference
::
maximum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"MaxPool"
)
{
op
::
MaxPool
*
max_pool
=
dynamic_cast
<
op
::
MaxPool
*>
(
&
node
);
reference
::
pad
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_interior
());
}
else
if
(
node_op
==
"Power"
)
{
reference
::
power
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reference
::
max_pool
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_movement_strides
(),
max_pool
->
get_padding_below
(),
max_pool
->
get_padding_above
());
}
else
if
(
node_op
==
"MaxPoolBackprop"
)
{
op
::
MaxPoolBackprop
*
max_pool_backprop
=
dynamic_cast
<
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Product"
)
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
product
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Reduce"
)
{
op
::
Reduce
*
reduce
=
dynamic_cast
<
op
::
Reduce
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce
->
get_functions
()[
0
];
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_movement_strides
(),
max_pool_backprop
->
get_padding_below
(),
max_pool_backprop
->
get_padding_above
());
}
else
if
(
node_op
==
"Min"
)
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
min
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Minimum"
)
{
reference
::
minimum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Multiply"
)
{
reference
::
multiply
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Negative"
)
{
reference
::
negate
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Not"
)
{
reference
::
logical_not
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"NotEqual"
)
{
reference
::
not_equal
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"OneHot"
)
{
auto
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
oh
->
get_one_hot_axis
());
}
else
if
(
node_op
==
"Or"
)
{
reference
::
logical_or
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
char
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Parameter"
)
{
}
else
if
(
node_op
==
"Pad"
)
{
op
::
Pad
*
pad
=
dynamic_cast
<
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_interior
());
}
else
if
(
node_op
==
"Power"
)
{
reference
::
power
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Product"
)
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
product
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Reduce"
)
{
op
::
Reduce
*
reduce
=
dynamic_cast
<
op
::
Reduce
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
f
);
}
else
if
(
node_op
==
"ReduceWindow"
)
{
op
::
ReduceWindow
*
reduce_window
=
dynamic_cast
<
op
::
ReduceWindow
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
reference
::
reduce
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reduce
->
get_reduction_axes
(),
f
);
}
else
if
(
node_op
==
"ReduceWindow"
)
{
op
::
ReduceWindow
*
reduce_window
=
dynamic_cast
<
op
::
ReduceWindow
*>
(
&
node
);
std
::
shared_ptr
<
Function
>
reduction_function
=
reduce_window
->
get_functions
()[
0
];
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
std
::
function
<
T
(
T
,
T
)
>
f
=
[
this
,
&
node
,
reduction_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"reduce_window_temp_x"
);
auto
ty
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"reduce_window_temp_y"
);
auto
tr
=
std
::
make_shared
<
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"reduce_window_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
reduction_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
reduce_window
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
reduce_window
->
get_window_shape
(),
reduce_window
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Relu"
)
{
reference
::
relu
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReluBackprop"
)
{
reference
::
relu_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReplaceSlice"
)
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Reshape"
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Result"
)
{
op
::
Result
*
res
=
dynamic_cast
<
op
::
Result
*>
(
&
node
);
reference
::
result
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
shape_size
(
res
->
get_shape
()));
}
else
if
(
node_op
==
"Reverse"
)
{
op
::
Reverse
*
reverse
=
dynamic_cast
<
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
reverse
->
get_reversed_axes
());
}
else
if
(
node_op
==
"Sign"
)
{
reference
::
sign
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sin"
)
{
reference
::
sin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sinh"
)
{
reference
::
sinh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Slice"
)
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Softmax"
)
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_shape
(),
softmax
->
get_axes
());
}
else
if
(
node_op
==
"Sqrt"
)
{
reference
::
sqrt
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Subtract"
)
{
reference
::
subtract
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sum"
)
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
sum
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Tan"
)
{
reference
::
tan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Tanh"
)
{
reference
::
tanh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
ngraph_error
(
ss
.
str
());
}
}
reference
::
reduce_window
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
node
.
get_inputs
().
at
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
f
,
reduce_window
->
get_window_shape
(),
reduce_window
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Relu"
)
{
reference
::
relu
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReluBackprop"
)
{
reference
::
relu_backprop
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"ReplaceSlice"
)
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
1
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Reshape"
)
{
op
::
Reshape
*
reshape
=
dynamic_cast
<
op
::
Reshape
*>
(
&
node
);
reference
::
reshape
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
reshape
->
get_input_order
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Result"
)
{
op
::
Result
*
res
=
dynamic_cast
<
op
::
Result
*>
(
&
node
);
reference
::
result
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
shape_size
(
res
->
get_shape
()));
}
else
if
(
node_op
==
"Reverse"
)
{
op
::
Reverse
*
reverse
=
dynamic_cast
<
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
reverse
->
get_reversed_axes
());
}
else
if
(
node_op
==
"Select"
)
{
reference
::
select
<
T
>
(
reinterpret_cast
<
char
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"SelectAndScatter"
)
{
ngraph
::
op
::
SelectAndScatter
*
select_and_scatter
=
dynamic_cast
<
ngraph
::
op
::
SelectAndScatter
*>
(
&
node
);
std
::
shared_ptr
<
ngraph
::
Function
>
selection_function
=
select_and_scatter
->
get_functions
()[
0
];
std
::
function
<
bool
(
T
,
T
)
>
f_selection
=
[
this
,
&
node
,
selection_function
](
T
x
,
T
y
)
->
bool
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"selection_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"selection_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
element
::
boolean
,
Shape
{},
"selection_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
selection_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
char
*>
(
tr
->
get_data_ptr
()));
};
std
::
shared_ptr
<
ngraph
::
Function
>
scatter_function
=
select_and_scatter
->
get_functions
()[
1
];
std
::
function
<
T
(
T
,
T
)
>
f_scatter
=
[
this
,
&
node
,
scatter_function
](
T
x
,
T
y
)
->
T
{
auto
tx
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
0
).
get_element_type
(),
Shape
{},
"scatter_temp_x"
);
auto
ty
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_inputs
().
at
(
1
).
get_element_type
(),
Shape
{},
"scatter_temp_y"
);
auto
tr
=
std
::
make_shared
<
runtime
::
HostTensorView
>
(
node
.
get_output_element_type
(
0
),
Shape
{},
"scatter_temp_r"
);
*
(
reinterpret_cast
<
T
*>
(
tx
->
get_data_ptr
()))
=
x
;
*
(
reinterpret_cast
<
T
*>
(
ty
->
get_data_ptr
()))
=
y
;
call
(
scatter_function
,
{
tr
},
{
tx
,
ty
});
return
*
(
reinterpret_cast
<
T
*>
(
tr
->
get_data_ptr
()));
};
reference
::
select_and_scatter
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
2
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
args
[
1
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
f_selection
,
f_scatter
,
select_and_scatter
->
get_window_shape
(),
select_and_scatter
->
get_window_movement_strides
());
}
else
if
(
node_op
==
"Sign"
)
{
reference
::
sign
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sin"
)
{
reference
::
sin
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sinh"
)
{
reference
::
sinh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Slice"
)
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_strides
(),
out
[
0
]
->
get_shape
());
}
else
if
(
node_op
==
"Softmax"
)
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_shape
(),
softmax
->
get_axes
());
}
else
if
(
node_op
==
"Sqrt"
)
{
reference
::
sqrt
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Subtract"
)
{
reference
::
subtract
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
args
[
1
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Sum"
)
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
args
[
0
]
->
get_shape
(),
out
[
0
]
->
get_shape
(),
sum
->
get_reduction_axes
());
}
else
if
(
node_op
==
"Tan"
)
{
reference
::
tan
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
if
(
node_op
==
"Tanh"
)
{
reference
::
tanh
<
T
>
(
reinterpret_cast
<
T
*>
(
args
[
0
]
->
get_data_ptr
()),
reinterpret_cast
<
T
*>
(
out
[
0
]
->
get_data_ptr
()),
out
[
0
]
->
get_element_count
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported op "
<<
node_op
;
throw
ngraph_error
(
ss
.
str
());
}
}
}
}
;
src/ngraph/runtime/reference/select.hpp
View file @
1eb9f9bf
...
...
@@ -17,6 +17,7 @@
#pragma once
#include <cstddef>
#include <iostream>
namespace
ngraph
{
...
...
test/autodiff.in.cpp
View file @
1eb9f9bf
...
...
@@ -829,7 +829,6 @@ TEST(${BACKEND_NAME}, backwards_log)
TEST
(
$
{
BACKEND_NAME
},
backwards_maximum
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
// no convert support
auto
backend
=
runtime
::
Backend
::
create
(
"${BACKEND_NAME}"
);
test
::
Uniform
<
float
>
rng
(
-
1.0
f
,
1.0
f
);
...
...
@@ -848,7 +847,6 @@ TEST(${BACKEND_NAME}, backwards_maximum)
TEST
(
$
{
BACKEND_NAME
},
backwards_minimum
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
// no convert support
auto
backend
=
runtime
::
Backend
::
create
(
"${BACKEND_NAME}"
);
test
::
Uniform
<
float
>
rng
(
-
1.0
f
,
1.0
f
);
...
...
@@ -1019,7 +1017,6 @@ TEST(${BACKEND_NAME}, backwards_reshape)
TEST
(
$
{
BACKEND_NAME
},
backwards_select
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"NNP_TESTER"
,
"${BACKEND_NAME}"
);
...
...
@@ -1049,7 +1046,6 @@ TEST(${BACKEND_NAME}, backwards_select)
TEST
(
$
{
BACKEND_NAME
},
backwards_select_nested
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"NNP_TESTER"
,
"${BACKEND_NAME}"
);
...
...
test/backend_test.in.cpp
View file @
1eb9f9bf
...
...
@@ -1349,7 +1349,6 @@ TEST(${BACKEND_NAME}, notequal)
TEST
(
$
{
BACKEND_NAME
},
select
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
boolean
,
shape
);
...
...
@@ -1767,7 +1766,6 @@ TEST(${BACKEND_NAME}, broadcast_matrix_2)
TEST
(
$
{
BACKEND_NAME
},
convert_int32_float32
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
shape
);
auto
f
=
...
...
@@ -1786,7 +1784,6 @@ TEST(${BACKEND_NAME}, convert_int32_float32)
TEST
(
$
{
BACKEND_NAME
},
convert_int32_bool
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
i32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Convert
>
(
A
,
element
::
boolean
),
...
...
@@ -1805,7 +1802,6 @@ TEST(${BACKEND_NAME}, convert_int32_bool)
TEST
(
$
{
BACKEND_NAME
},
convert_float32_bool
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
Shape
shape
{
2
,
2
};
auto
A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape
);
auto
f
=
make_shared
<
Function
>
(
make_shared
<
op
::
Convert
>
(
A
,
element
::
boolean
),
...
...
@@ -5148,7 +5144,6 @@ TEST(${BACKEND_NAME}, reduce_window_emulating_max_pool_2d_1channel_1image_stride
//
TEST
(
$
{
BACKEND_NAME
},
select_and_scatter_with_overlap
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_sel_a
{};
auto
SEL_A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_sel_a
);
...
...
@@ -5203,7 +5198,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_with_overlap)
//
TEST
(
$
{
BACKEND_NAME
},
select_and_scatter_without_overlap
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_sel_a
{};
auto
SEL_A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_sel_a
);
...
...
@@ -5258,7 +5252,6 @@ TEST(${BACKEND_NAME}, select_and_scatter_without_overlap)
//
TEST
(
$
{
BACKEND_NAME
},
select_and_scatter_3d_without_overlap
)
{
SKIP_TEST_FOR
(
"IE"
,
"${BACKEND_NAME}"
);
SKIP_TEST_FOR
(
"GPU"
,
"${BACKEND_NAME}"
);
Shape
shape_sel_a
{};
auto
SEL_A
=
make_shared
<
op
::
Parameter
>
(
element
::
f32
,
shape_sel_a
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment