Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
47342339
Unverified
Commit
47342339
authored
Jul 09, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 09, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #3178 from NervanaSystems/bob/gcpu
Update generic CPU backend to latest ngraph API
parents
30527e80
c7630c05
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
703 additions
and
400 deletions
+703
-400
CMakeLists.txt
src/ngraph/runtime/generic_cpu/CMakeLists.txt
+4
-4
gcpu_backend.cpp
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
+2
-2
gcpu_executable.cpp
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
+53
-36
gcpu_executable.hpp
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
+547
-312
broadcast.hpp
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
+95
-0
reshape.hpp
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
+1
-4
result.hpp
src/ngraph/runtime/generic_cpu/kernel/result.hpp
+0
-41
node_wrapper.hpp
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
+1
-1
No files found.
src/ngraph/runtime/generic_cpu/CMakeLists.txt
View file @
47342339
...
...
@@ -15,10 +15,10 @@
# ******************************************************************************
if
(
NGRAPH_GENERIC_CPU_ENABLE
)
find_package
(
OpenMP
)
if
(
OPENMP_FOUND
)
add_compile_options
(
${
OpenMP_CXX_FLAGS
}
)
endif
()
#
find_package(OpenMP)
#
if (OPENMP_FOUND)
#
add_compile_options(${OpenMP_CXX_FLAGS})
#
endif()
add_library
(
gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp
)
if
(
NGRAPH_LIB_VERSIONING_ENABLE
)
set_target_properties
(
gcpu_backend PROPERTIES
...
...
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
View file @
47342339
...
...
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
);
}
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
);
}
shared_ptr
<
runtime
::
Executable
>
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
View file @
47342339
...
...
@@ -15,17 +15,22 @@
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
...
...
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
enable_performance_collection
}
{
m_function
=
clone_function
(
*
function
);
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
FusedOpDecomposition
>
();
pass_manager
.
register_pass
<
pass
::
ImplicitBroadcastElimination
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
m_function
);
for
(
const
shared_ptr
<
Node
>&
node
:
m_function
->
get_ordered_ops
())
{
m_is_compiled
=
true
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
function
);
m_wrapped_nodes
.
emplace_back
(
node
);
}
set_parameters_and_results
(
*
m_function
);
}
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
{
m_wrapped_nodes
.
emplace_back
(
node
);
}
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
std
::
string
&
model_string
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
false
}
{
m_function
=
deserialize
(
model_string
);
for
(
const
shared_ptr
<
Node
>&
node
:
m_function
->
get_ordered_ops
())
{
m_wrapped_nodes
.
emplace_back
(
node
);
}
set_parameters_and_results
(
*
function
);
set_parameters_and_results
(
*
m_
function
);
}
bool
runtime
::
gcpu
::
GCPUExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
...
...
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
tensor
=
param
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
param
->
output
(
i
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_inputs
[
input_count
++
]});
}
}
...
...
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
Tensor
*
tensor
=
output
->
get_output_tensor_ptr
(
0
).
get
();
descriptor
::
Tensor
*
tensor
=
&
output
->
output
(
0
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_outputs
[
output_count
]});
}
// for each ordered op in the graph
for
(
const
NodeWrapper
&
wrapped
:
m_wrapped_nodes
)
{
const
Node
*
op
=
&
wrapped
.
get_node
();
auto
op
=
wrapped
.
get_node
();
auto
type_id
=
wrapped
.
get_typeid
();
if
(
type_id
==
OP_TYPEID
::
Parameter
)
{
...
...
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
// get op inputs from map
vector
<
shared_ptr
<
HostTensor
>>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_
inputs
())
for
(
auto
input
:
op
->
inputs
())
{
descriptor
::
Tensor
*
tensor
=
input
.
get_output
().
get_tensor_ptr
().
get
();
descriptor
::
Tensor
*
tensor
=
&
input
.
get_tensor
();
op_inputs
.
push_back
(
tensor_map
.
at
(
tensor
));
}
...
...
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
vector
<
shared_ptr
<
HostTensor
>>
op_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
tensor
=
op
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
op
->
output
(
i
).
get_tensor
();
shared_ptr
<
HostTensor
>
host_tensor
;
auto
it
=
tensor_map
.
find
(
tensor
);
if
(
it
==
tensor_map
.
end
())
{
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
string
name
=
op
->
output
(
i
).
get_tensor
(
).
get_name
();
host_tensor
=
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tensor
,
host_tensor
});
}
...
...
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
}
if
(
m_nan_check_enabled
)
{
perform_nan_check
(
op_outputs
,
op
);
perform_nan_check
(
op_outputs
,
op
.
get
()
);
}
}
...
...
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
void
runtime
::
gcpu
::
GCPUExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
vector
<
shared_ptr
<
HostTensor
>>&
out
puts
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
puts
)
const
vector
<
shared_ptr
<
HostTensor
>>&
out
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
vector
<
void
*>
out
;
vector
<
const
void
*>
in
;
for
(
auto
t
:
outputs
)
{
out
.
push_back
(
t
->
get_data_ptr
());
}
for
(
auto
t
:
inputs
)
{
in
.
push_back
(
t
->
get_data_ptr
());
}
stringstream
ss
;
switch
(
type
.
get_type_enum
())
{
...
...
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
().
get_name
();
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
()
->
get_name
();
throw
ngraph_error
(
ss
.
str
());
}
}
...
...
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
vector
<
runtime
::
PerformanceCounter
>
runtime
::
gcpu
::
GCPUExecutable
::
get_performance_data
()
const
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
m_timer_map
)
for
(
const
pair
<
shared_ptr
<
const
Node
>
,
stopwatch
>
p
:
m_timer_map
)
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
rc
.
emplace_back
(
p
.
first
,
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
}
return
rc
;
}
...
...
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
arg_number
++
;
}
}
void
runtime
::
gcpu
::
GCPUExecutable
::
save
(
ostream
&
out
)
{
cpio
::
Writer
writer
(
out
);
string
si
=
"INTERPRETER Save File 1.0"
;
writer
.
write
(
"save_info"
,
si
.
data
(),
si
.
size
());
string
model
=
serialize
(
m_function
,
0
);
writer
.
write
(
"model"
,
model
.
data
(),
model
.
size
());
}
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
View file @
47342339
...
...
@@ -17,24 +17,31 @@
#pragma once
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
...
...
@@ -48,11 +55,14 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
...
...
@@ -64,7 +74,6 @@
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
...
...
@@ -77,7 +86,9 @@
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
...
...
@@ -89,8 +100,10 @@
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
...
...
@@ -117,14 +130,17 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
...
...
@@ -134,6 +150,7 @@
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
...
...
@@ -154,6 +171,8 @@ namespace ngraph
class
ngraph
::
runtime
::
gcpu
::
GCPUExecutable
:
public
Executable
{
friend
class
GCPUBackend
;
public
:
GCPUExecutable
(
const
std
::
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
=
false
);
...
...
@@ -161,20 +180,25 @@ public:
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
virtual
void
save
(
std
::
ostream
&
output_stream
)
override
;
void
set_nan_check
(
bool
enable
);
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
override
;
private
:
GCPUExecutable
(
const
std
::
string
&
model_string
);
int
get_alignment
()
const
{
return
64
;
}
bool
m_is_compiled
=
false
;
bool
m_nan_check_enabled
=
false
;
bool
m_performance_counters_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
std
::
shared_ptr
<
Function
>
m_function
;
std
::
unordered_map
<
std
::
shared_ptr
<
const
Node
>
,
stopwatch
>
m_timer_map
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
int
get_alignment
()
const
{
return
64
;
}
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
const
Node
*
op
=
nullptr
);
...
...
@@ -185,11 +209,10 @@ private:
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
void
*
>&
out
,
const
std
::
vector
<
const
void
*
>&
args
)
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
args
)
{
const
Node
&
node
=
node_wrapper
.
get_node
();
std
::
string
node_op
=
node
.
description
();
const
Node
&
node
=
*
node_wrapper
.
get_node
();
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
...
...
@@ -206,30 +229,30 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
abs
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Acos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
acos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Add
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
add
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
add
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
All
:
{
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
reference
::
all
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
all
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
all
->
get_reduction_axes
());
...
...
@@ -237,26 +260,29 @@ private:
}
case
OP_TYPEID
:
:
AllReduce
:
{
reference
::
allreduce
<
T
>
(
static_cast
<
T
*>
(
const_cast
<
void
*>
(
args
[
0
])),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_element_type
(
0
),
const
ngraph
::
op
::
AllReduce
*
allreduce
=
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
}
case
OP_TYPEID
:
:
And
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_and
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
logical_and
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Any
:
{
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
reference
::
any
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
any
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
any
->
get_reduction_axes
());
...
...
@@ -268,16 +294,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmin
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
reference
::
argmin
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
}
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmin
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
argmin
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
...
...
@@ -294,16 +320,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmax
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
reference
::
argmax
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
}
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmax
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
argmax
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
...
...
@@ -318,22 +344,22 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
asin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Atan
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
atan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
AvgPool
:
{
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
avg_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
avg_pool
->
get_window_shape
(),
...
...
@@ -345,18 +371,30 @@ private:
}
case
OP_TYPEID
:
:
GenerateMask
:
{
bool
use_seed
=
static_cast
<
bool
>
(
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
()[
0
]);
if
(
m_states
.
count
(
&
node
)
==
0
)
{
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
auto
seed
=
use_seed
?
gm
->
get_seed
()
:
0
;
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
()
,
gm
->
get_probability
()));
ngraph
::
RNGState
::
create_rng_state
(
seed
,
gm
->
get_probability
()));
}
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
]
)[
0
]);
bool
training
=
static_cast
<
bool
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
)[
0
]);
auto
state
=
m_states
.
at
(
&
node
).
get
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
generate_mask
<
T
>
(
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
if
(
!
use_seed
)
{
reference
::
generate_mask
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
state
,
training
);
}
else
{
uint64_t
seed
=
static_cast
<
uint64_t
>
(
args
[
3
]
->
get_data_ptr
<
const
T
>
()[
0
]);
double
prob
=
static_cast
<
double
>
(
args
[
4
]
->
get_data_ptr
<
const
T
>
()[
0
]);
reference
::
generate_mask_no_state
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
training
,
seed
,
prob
);
}
break
;
}
case
OP_TYPEID
:
:
GetOutputElement
:
...
...
@@ -366,20 +404,31 @@ private:
size_t
n
=
get_output_element
->
get_n
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
std
::
memcpy
(
static_cast
<
T
*>
(
out
[
0
]),
args
[
n
]
,
num_bytes
);
std
::
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
n
]
->
get_data_ptr
<
T
>
()
,
num_bytes
);
break
;
}
case
OP_TYPEID
:
:
BatchMatMul
:
{
reference
::
batch_mat_mul
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
BatchNormTraining
:
{
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
T
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
static_cast
<
T
*>
(
out
[
2
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
...
...
@@ -388,12 +437,12 @@ private:
const
ngraph
::
op
::
BatchNormInference
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
T
*>
(
out
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
...
...
@@ -402,23 +451,23 @@ private:
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
const
T
*>
(
args
[
5
]
),
static_cast
<
T
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
static_cast
<
T
*>
(
out
[
2
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
args
[
5
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
case
OP_TYPEID
:
:
AvgPoolBackprop
:
{
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
avg_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
apb
->
get_window_shape
(),
...
...
@@ -434,32 +483,37 @@ private:
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
gcpu
::
kernel
::
broadcast
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
in_shape
,
out_shape
,
broadcast_axes
);
kernel
::
broadcast
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shape
,
out_shape
,
broadcast_axes
);
break
;
}
case
OP_TYPEID
:
:
BroadcastDistributed
:
{
int
rank_ID
=
get_distributed_interface
()
->
get_rank
();
if
(
rank_ID
==
0
)
const
ngraph
::
op
::
BroadcastDistributed
*
broadcast
=
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
&
node
);
int
rank_ID
;
rank_ID
=
get_distributed_interface
()
->
get_rank
();
int
root_id
=
broadcast
->
get_root_id
();
if
(
rank_ID
==
root_id
)
{
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
args
[
0
]
),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
node
.
get_input_element_type
(
0
)
);
memcpy
(
out
[
0
]
,
args
[
0
]
,
memSize
);
args
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
,
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
()
,
memSize
);
}
else
{
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
}
break
;
}
...
...
@@ -468,7 +522,7 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
ceiling
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Concat
:
...
...
@@ -478,11 +532,11 @@ private:
std
::
vector
<
Shape
>
in_shapes
;
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
{
in_args
.
push_back
(
static_cast
<
const
T
*>
(
args
[
i
]
));
in_args
.
push_back
(
args
[
i
]
->
get_data_ptr
<
const
T
>
(
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
}
reference
::
concat
<
T
>
(
in_args
,
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shapes
,
node
.
get_output_shape
(
0
),
concat
->
get_concatenation_axis
());
...
...
@@ -492,7 +546,7 @@ private:
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
...
...
@@ -505,52 +559,62 @@ private:
switch
(
type
.
get_type_enum
())
{
case
element
:
:
Type_t
::
boolean
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
char
*>
(
out
[
0
]
),
element_count
);
reference
::
convert
_to_bool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
case
element
:
:
Type_t
::
f32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
float
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
element_count
);
break
;
case
element
:
:
Type_t
::
f64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
double
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
double
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i8
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int8_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i16
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int16_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int16_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int32_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int64_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u8
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint8_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u16
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint16_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint16_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint32_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint32_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint64_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
...
...
@@ -559,9 +623,9 @@ private:
case
OP_TYPEID
:
:
Convolution
:
{
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
...
...
@@ -569,38 +633,26 @@ private:
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
c
->
get_data_dilation_strides
());
break
;
}
case
OP_TYPEID
:
:
ConvolutionBackpropFilters
:
{
const
op
::
ConvolutionBackpropFilters
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
reference
::
convolution_backprop_filter
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
// input
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
// delta_convolution_output
out
[
0
]
->
get_data_ptr
<
T
>
(),
// delta_filter
c
->
get_input_shape
(
0
),
// input_shape
c
->
get_input_shape
(
1
),
// convolution_output_shape
c
->
get_filters_shape
(),
// filter_shape
c
->
get_window_dilation_strides_forward
(),
c
->
get_window_movement_strides_forward
(),
c
->
get_padding_below_forward
(),
c
->
compute_backward_in_pad_above
(),
c
->
get_data_dilation_strides_forward
());
break
;
}
case
OP_TYPEID
:
:
ConvolutionBackpropData
:
...
...
@@ -608,38 +660,31 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const
op
::
ConvolutionBackpropData
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
reference
::
convolution_backprop_in
<
T
>
(
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
c
->
get_input_shape
(
1
),
c
->
get_input_shape
(
0
),
c
->
get_data_batch_shape
(),
c
->
get_data_dilation_strides_forward
(),
c
->
get_window_dilation_strides_forward
(),
c
->
compute_backward_delta_out_pad_below
(),
c
->
compute_backward_delta_out_pad_above
(),
c
->
get_window_movement_strides_forward
());
break
;
}
case
OP_TYPEID
:
:
Cos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Cosh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cosh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Dequantize
:
...
...
@@ -649,20 +694,20 @@ private:
if
(
type
==
element
::
f32
)
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
float
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
float
*>
(
out
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
float
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
double
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
double
*>
(
out
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
double
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
double
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
...
...
@@ -680,9 +725,9 @@ private:
{
const
op
::
Divide
*
divop
=
static_cast
<
const
op
::
Divide
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
divide
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
divide
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
divop
->
is_pythondiv
());
break
;
...
...
@@ -691,13 +736,23 @@ private:
{
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
gcpu
::
kernel
::
dot
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
dot
->
get_reduction_axes_count
());
kernel
::
dot
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
dot
->
get_reduction_axes_count
());
break
;
}
case
OP_TYPEID
:
:
DynReshape
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
DynSlice
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
EmbeddingLookup
:
...
...
@@ -708,33 +763,33 @@ private:
if
(
type
==
element
::
f32
)
{
reference
::
embedding
<
T
,
float
>
(
static_cast
<
const
float
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
float
>
(
args
[
0
]
->
get_data_ptr
<
const
float
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
embedding
<
T
,
double
>
(
static_cast
<
const
double
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
double
>
(
args
[
0
]
->
get_data_ptr
<
const
double
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
i32
)
{
reference
::
embedding
<
T
,
int
>
(
static_cast
<
const
int
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
,
embed
->
get_shape
());
reference
::
embedding
<
T
,
int
32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
i64
)
{
reference
::
embedding
<
T
,
int64_t
>
(
static_cast
<
const
int64_t
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int64_t
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
...
...
@@ -748,24 +803,56 @@ private:
case
OP_TYPEID
:
:
Equal
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Erf
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
erf
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Exp
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
exp
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
#ifdef INTERPRETER_USE_HYBRID
case
OP_TYPEID
:
:
FunctionCall
:
{
auto
f
=
static_cast
<
const
runtime
::
hybrid
::
op
::
FunctionCall
*>
(
&
node
);
auto
backend
=
f
->
get_backend
();
auto
executable
=
f
->
get_executable
();
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
outputs
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
inputs
;
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
out
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
outputs
.
push_back
(
backend_tensor
);
}
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
args
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
inputs
.
push_back
(
backend_tensor
);
}
executable
->
call
(
outputs
,
inputs
);
break
;
}
#endif
case
OP_TYPEID
:
:
Floor
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
floor
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Gather
:
...
...
@@ -826,36 +913,36 @@ private:
case
OP_TYPEID
:
:
Greater
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
greater
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
GreaterEq
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
greater_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Less
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
less
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LessEq
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
less_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
...
...
@@ -863,14 +950,14 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
log
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LRN
:
{
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
reference
::
lrn
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
lrn
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
lrn
->
get_alpha
(),
lrn
->
get_beta
(),
...
...
@@ -881,8 +968,8 @@ private:
case
OP_TYPEID
:
:
Max
:
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max
->
get_reduction_axes
());
...
...
@@ -891,9 +978,9 @@ private:
case
OP_TYPEID
:
:
Maximum
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
maximum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
maximum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -901,8 +988,8 @@ private:
{
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool
->
get_window_shape
(),
...
...
@@ -916,9 +1003,9 @@ private:
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
max_pool_backprop
->
get_window_shape
(),
...
...
@@ -930,8 +1017,8 @@ private:
case
OP_TYPEID
:
:
Min
:
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
min
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
min
->
get_reduction_axes
());
...
...
@@ -940,18 +1027,18 @@ private:
case
OP_TYPEID
:
:
Minimum
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
minimum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
minimum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Multiply
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
multiply
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
multiply
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -959,30 +1046,30 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
negate
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Not
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_not
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
NotEqual
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
not_equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
not_equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
OneHot
:
{
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
one_hot
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
oh
->
get_one_hot_axis
());
...
...
@@ -991,46 +1078,46 @@ private:
case
OP_TYPEID
:
:
Or
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_or
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
logical_or
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Parameter
:
break
;
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Pad
:
{
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_inputs
().
a
t
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reference
::
pad
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
inpu
t
(
0
).
get_shape
(),
node
.
output
(
0
).
get_shape
(
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_pad
ding_interior
());
pad
->
get_pad
_mode
());
break
;
}
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Power
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
power
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
power
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Product
:
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
product
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
product
->
get_reduction_axes
());
...
...
@@ -1043,10 +1130,10 @@ private:
if
(
type
==
element
::
u8
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
uint8_t
*>
(
args
[
2
]
),
static_cast
<
uint8_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
uint8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1054,10 +1141,10 @@ private:
}
else
if
(
type
==
element
::
i8
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
int8_t
*>
(
args
[
2
]
),
static_cast
<
int8_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
int8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1065,10 +1152,10 @@ private:
}
else
if
(
type
==
element
::
i32
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
int32_t
*>
(
args
[
2
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1083,40 +1170,168 @@ private:
break
;
}
case
OP_TYPEID
:
:
QuantizedConvolution
:
{
const
op
::
QuantizedConvolution
*
qc
=
static_cast
<
const
op
::
QuantizedConvolution
*>
(
&
node
);
auto
input_element_type
=
qc
->
get_input_element_type
(
0
);
auto
filter_element_type
=
qc
->
get_input_element_type
(
1
);
auto
output_element_type
=
qc
->
get_output_element_type
(
0
);
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i8
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
u8
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
uint8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
uint8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
break
;
}
case
OP_TYPEID
:
:
QuantizedAvgPool
:
case
OP_TYPEID
:
:
QuantizedConvolutionBias
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasSignedAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionRelu
:
case
OP_TYPEID
:
:
QuantizedConvolution
:
case
OP_TYPEID
:
:
QuantizedMaxPool
:
case
OP_TYPEID
:
:
QuantizedDotBias
:
case
OP_TYPEID
:
:
QuantizedDot
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'."
);
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"' in Interpreter back end."
);
}
case
OP_TYPEID
:
:
Recv
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
Range
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
Relu
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReluBackprop
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
relu_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReplaceSlice
:
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
replace_slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
...
...
@@ -1127,26 +1342,26 @@ private:
case
OP_TYPEID
:
:
Reshape
:
{
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
gcpu
::
kernel
::
reshape
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
node
.
get_output_shape
(
0
));
kernel
::
reshape
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
Result
:
{
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
reference
::
result
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
result
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
shape_size
(
res
->
get_shape
()));
break
;
}
case
OP_TYPEID
:
:
Reverse
:
{
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
reverse
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
reverse
->
get_reversed_axes
());
...
...
@@ -1158,12 +1373,12 @@ private:
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
{
reference
::
reverse_sequence
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
reverse_sequence
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
reverse
->
get_batch_axis
(),
reverse
->
get_sequence_axis
(),
static_cast
<
const
int32_t
*>
(
args
[
1
]
));
args
[
1
]
->
get_data_ptr
<
const
int32_t
>
(
));
}
else
{
...
...
@@ -1234,31 +1449,46 @@ private:
case
OP_TYPEID
:
:
Select
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
select
<
T
>
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
select
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Send
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Send
*>
(
&
node
);
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
dest_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
ShapeOf
:
{
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
static_cast
<
uint64_t
*>
(
out
[
0
]
));
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(
));
break
;
}
case
OP_TYPEID
:
:
Sigmoid
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
SigmoidBackprop
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
sigmoid_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -1266,28 +1496,28 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sign
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sin
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sinh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sinh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Slice
:
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
...
...
@@ -1298,8 +1528,8 @@ private:
case
OP_TYPEID
:
:
Softmax
:
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
softmax
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_output_shape
(
0
),
softmax
->
get_axes
());
break
;
...
...
@@ -1308,7 +1538,7 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sqrt
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
...
...
@@ -1316,17 +1546,17 @@ private:
case
OP_TYPEID
:
:
Subtract
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
subtract
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
subtract
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sum
:
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
sum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
sum
->
get_reduction_axes
());
...
...
@@ -1336,14 +1566,14 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Tanh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tanh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
TopK
:
...
...
@@ -1351,9 +1581,9 @@ private:
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
{
reference
::
topk
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
reference
::
topk
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
...
...
@@ -1362,9 +1592,9 @@ private:
}
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
{
reference
::
topk
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
reference
::
topk
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
...
...
@@ -1377,7 +1607,12 @@ private:
}
break
;
}
default
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
case
OP_TYPEID
:
:
DynBroadcast
:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
Tile
:
case
OP_TYPEID
:
:
DynReplaceSlice
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
...
...
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
View file @
47342339
...
...
@@ -140,6 +140,91 @@ namespace ngraph
}
}
template
<
typename
T
>
void
broadcast_5d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
5
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
5
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
3
]
*
out_shape
[
4
]
+
index
[
4
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
template
<
typename
T
>
void
broadcast_6d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
6
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
6
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
for
(
index
[
5
]
=
0
;
index
[
5
]
<
out_shape
[
5
];
++
index
[
5
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
4
]
*
out_shape
[
5
]
+
index
[
5
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
}
template
<
typename
T
>
void
broadcast
(
const
T
*
in
,
T
*
out
,
...
...
@@ -167,6 +252,16 @@ namespace ngraph
case
4
:
broadcast_4d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
case
5
:
broadcast_5d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
case
6
:
broadcast_6d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
default
:
runtime
::
reference
::
broadcast
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
}
}
else
...
...
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
View file @
47342339
...
...
@@ -244,10 +244,7 @@ namespace ngraph
case
4
:
reshape_in4
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
5
:
reshape_in5
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
6
:
reshape_in6
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
default
:
NGRAPH_INFO
<<
"reference::reshape"
;
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
default
:
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
}
}
}
...
...
src/ngraph/runtime/generic_cpu/kernel/result.hpp
deleted
100644 → 0
View file @
30527e80
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
gcpu
{
namespace
kernel
{
template
<
typename
T
>
void
result
(
const
T
*
arg
,
T
*
out
,
size_t
count
)
{
memcpy
(
out
,
arg
,
sizeof
(
T
)
*
count
);
}
}
}
}
}
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
View file @
47342339
...
...
@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public
:
NodeWrapper
(
const
std
::
shared_ptr
<
const
ngraph
::
Node
>&
node
);
const
Node
&
get_node
()
const
{
return
*
m_node
;
}
std
::
shared_ptr
<
const
Node
>
get_node
()
const
{
return
m_node
;
}
ngraph
::
runtime
::
gcpu
::
OP_TYPEID
get_typeid
()
const
{
return
m_typeid
;
}
private
:
std
::
shared_ptr
<
const
ngraph
::
Node
>
m_node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment