Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
334a55fa
Unverified
Commit
334a55fa
authored
Jul 09, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 09, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into rearhart/plaidml
parents
5cfe1075
47342339
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
693 additions
and
401 deletions
+693
-401
external_mkldnn.cmake
cmake/external_mkldnn.cmake
+3
-3
mkldnn.patch
cmake/mkldnn.patch
+0
-13
test_requirements.txt
python/test_requirements.txt
+1
-0
allocator.cpp
src/ngraph/runtime/allocator.cpp
+3
-2
allocator.hpp
src/ngraph/runtime/allocator.hpp
+1
-1
cpu_backend.cpp
src/ngraph/runtime/cpu/cpu_backend.cpp
+1
-1
CMakeLists.txt
src/ngraph/runtime/generic_cpu/CMakeLists.txt
+4
-4
gcpu_backend.cpp
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
+2
-2
gcpu_executable.cpp
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
+47
-30
gcpu_executable.hpp
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
+534
-299
broadcast.hpp
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
+95
-0
reshape.hpp
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
+1
-4
result.hpp
src/ngraph/runtime/generic_cpu/kernel/result.hpp
+0
-41
node_wrapper.hpp
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
+1
-1
No files found.
cmake/external_mkldnn.cmake
View file @
334a55fa
...
@@ -18,10 +18,10 @@ include(ExternalProject)
...
@@ -18,10 +18,10 @@ include(ExternalProject)
# Includes blas 3.8.0 in mkldnn
# Includes blas 3.8.0 in mkldnn
set
(
NGRAPH_MKLDNN_SHORT_VERSION 0
)
set
(
NGRAPH_MKLDNN_SHORT_VERSION 0
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 0.
19
.0.0
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 0.
20
.0.0
)
set
(
NGRAPH_MKLDNN_VERSION
"v0.
19
"
)
set
(
NGRAPH_MKLDNN_VERSION
"v0.
20
"
)
set
(
NGRAPH_MKLDNN_SUB_VERSION
"2019.0.5.20190502"
)
set
(
NGRAPH_MKLDNN_SUB_VERSION
"2019.0.5.20190502"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"
027de76
"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"
v0.20
"
)
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
# Fetch and install MKL-DNN
# Fetch and install MKL-DNN
...
...
cmake/mkldnn.patch
View file @
334a55fa
...
@@ -28,16 +28,3 @@ index f10feb20..05f47961 100644
...
@@ -28,16 +28,3 @@ index f10feb20..05f47961 100644
set_property(TARGET ${LIB_NAME} PROPERTY PUBLIC_HEADER ${HEADERS})
set_property(TARGET ${LIB_NAME} PROPERTY PUBLIC_HEADER ${HEADERS})
target_include_directories(${LIB_NAME} PUBLIC
target_include_directories(${LIB_NAME} PUBLIC
diff --git a/src/cpu/jit_avx512_common_conv_kernel.cpp b/src/cpu/jit_avx512_common_conv_kernel.cpp
index 1bb98fa43..b8b54401f 100644
--- a/src/cpu/jit_avx512_common_conv_kernel.cpp
+++ b/src/cpu/jit_avx512_common_conv_kernel.cpp
@@ -3055,7 +3055,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_3d() {
void jit_avx512_common_conv_bwd_weights_kernel_f32
::compute_oh_loop_common()
{
- assert(jcp.harness == harness_mb_reduction);
+ assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));
int b_pad = jcp.b_pad;
int t_pad = jcp.t_pad;
bool is_dilated = jcp.dilate_h != 0;
python/test_requirements.txt
View file @
334a55fa
pytest
pytest
tox
tox
pydocstyle==3.0.0
flake8
flake8
flake8-commas
flake8-commas
flake8-comprehensions
flake8-comprehensions
...
...
src/ngraph/runtime/allocator.cpp
View file @
334a55fa
...
@@ -49,7 +49,8 @@ public:
...
@@ -49,7 +49,8 @@ public:
}
}
};
};
std
::
unique_ptr
<
ngraph
::
runtime
::
Allocator
>
ngraph
::
runtime
::
create
_default_allocator
()
ngraph
::
runtime
::
Allocator
*
ngraph
::
runtime
::
get
_default_allocator
()
{
{
return
std
::
unique_ptr
<
DefaultAllocator
>
(
new
DefaultAllocator
());
static
std
::
unique_ptr
<
DefaultAllocator
>
allocator
(
new
DefaultAllocator
());
return
allocator
.
get
();
}
}
src/ngraph/runtime/allocator.hpp
View file @
334a55fa
...
@@ -30,7 +30,7 @@ namespace ngraph
...
@@ -30,7 +30,7 @@ namespace ngraph
class
DefaultAllocator
;
class
DefaultAllocator
;
/// \brief Create a default allocator that calls into system
/// \brief Create a default allocator that calls into system
/// allocation libraries
/// allocation libraries
std
::
unique_ptr
<
Allocator
>
create
_default_allocator
();
ngraph
::
runtime
::
Allocator
*
get
_default_allocator
();
}
}
}
}
...
...
src/ngraph/runtime/cpu/cpu_backend.cpp
View file @
334a55fa
...
@@ -185,7 +185,7 @@ runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
...
@@ -185,7 +185,7 @@ runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
{
{
if
(
!
m_allocator
)
if
(
!
m_allocator
)
{
{
m_allocator
=
create
_default_allocator
();
return
runtime
::
get
_default_allocator
();
}
}
return
m_allocator
.
get
();
return
m_allocator
.
get
();
}
}
...
...
src/ngraph/runtime/generic_cpu/CMakeLists.txt
View file @
334a55fa
...
@@ -15,10 +15,10 @@
...
@@ -15,10 +15,10 @@
# ******************************************************************************
# ******************************************************************************
if
(
NGRAPH_GENERIC_CPU_ENABLE
)
if
(
NGRAPH_GENERIC_CPU_ENABLE
)
find_package
(
OpenMP
)
#
find_package(OpenMP)
if
(
OPENMP_FOUND
)
#
if (OPENMP_FOUND)
add_compile_options
(
${
OpenMP_CXX_FLAGS
}
)
#
add_compile_options(${OpenMP_CXX_FLAGS})
endif
()
#
endif()
add_library
(
gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp
)
add_library
(
gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp
)
if
(
NGRAPH_LIB_VERSIONING_ENABLE
)
if
(
NGRAPH_LIB_VERSIONING_ENABLE
)
set_target_properties
(
gcpu_backend PROPERTIES
set_target_properties
(
gcpu_backend PROPERTIES
...
...
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
View file @
334a55fa
...
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
...
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
const
Shape
&
shape
)
{
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
);
}
}
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
const
Shape
&
shape
,
void
*
memory_pointer
)
void
*
memory_pointer
)
{
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
);
}
}
shared_ptr
<
runtime
::
Executable
>
shared_ptr
<
runtime
::
Executable
>
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
View file @
334a55fa
...
@@ -15,17 +15,22 @@
...
@@ -15,17 +15,22 @@
//*****************************************************************************
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
#include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
using
namespace
std
;
...
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
...
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
shared_ptr
<
Function
>&
function
,
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
)
bool
enable_performance_collection
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
enable_performance_collection
}
{
{
{
m_function
=
clone_function
(
*
function
);
m_is_compiled
=
true
;
pass
::
Manager
pass_manager
;
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
FusedOpDecomposition
>
();
pass_manager
.
register_pass
<
pass
::
ImplicitBroadcastElimination
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
function
);
pass_manager
.
run_passes
(
m_
function
);
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
for
(
const
shared_ptr
<
Node
>&
node
:
m_
function
->
get_ordered_ops
())
{
{
m_wrapped_nodes
.
emplace_back
(
node
);
m_wrapped_nodes
.
emplace_back
(
node
);
}
}
set_parameters_and_results
(
*
m_function
);
}
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
std
::
string
&
model_string
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
false
}
{
m_function
=
deserialize
(
model_string
);
for
(
const
shared_ptr
<
Node
>&
node
:
m_function
->
get_ordered_ops
())
{
m_wrapped_nodes
.
emplace_back
(
node
);
}
}
set_parameters_and_results
(
*
function
);
set_parameters_and_results
(
*
m_
function
);
}
}
bool
runtime
::
gcpu
::
GCPUExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
bool
runtime
::
gcpu
::
GCPUExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
...
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
{
descriptor
::
Tensor
*
tensor
=
param
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
param
->
output
(
i
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_inputs
[
input_count
++
]});
tensor_map
.
insert
({
tensor
,
func_inputs
[
input_count
++
]});
}
}
}
}
...
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
}
descriptor
::
Tensor
*
tensor
=
output
->
get_output_tensor_ptr
(
0
).
get
();
descriptor
::
Tensor
*
tensor
=
&
output
->
output
(
0
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_outputs
[
output_count
]});
tensor_map
.
insert
({
tensor
,
func_outputs
[
output_count
]});
}
}
// for each ordered op in the graph
// for each ordered op in the graph
for
(
const
NodeWrapper
&
wrapped
:
m_wrapped_nodes
)
for
(
const
NodeWrapper
&
wrapped
:
m_wrapped_nodes
)
{
{
const
Node
*
op
=
&
wrapped
.
get_node
();
auto
op
=
wrapped
.
get_node
();
auto
type_id
=
wrapped
.
get_typeid
();
auto
type_id
=
wrapped
.
get_typeid
();
if
(
type_id
==
OP_TYPEID
::
Parameter
)
if
(
type_id
==
OP_TYPEID
::
Parameter
)
{
{
...
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
// get op inputs from map
// get op inputs from map
vector
<
shared_ptr
<
HostTensor
>>
op_inputs
;
vector
<
shared_ptr
<
HostTensor
>>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_
inputs
())
for
(
auto
input
:
op
->
inputs
())
{
{
descriptor
::
Tensor
*
tensor
=
input
.
get_output
().
get_tensor_ptr
().
get
();
descriptor
::
Tensor
*
tensor
=
&
input
.
get_tensor
();
op_inputs
.
push_back
(
tensor_map
.
at
(
tensor
));
op_inputs
.
push_back
(
tensor_map
.
at
(
tensor
));
}
}
...
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
vector
<
shared_ptr
<
HostTensor
>>
op_outputs
;
vector
<
shared_ptr
<
HostTensor
>>
op_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
{
descriptor
::
Tensor
*
tensor
=
op
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
op
->
output
(
i
).
get_tensor
();
shared_ptr
<
HostTensor
>
host_tensor
;
shared_ptr
<
HostTensor
>
host_tensor
;
auto
it
=
tensor_map
.
find
(
tensor
);
auto
it
=
tensor_map
.
find
(
tensor
);
if
(
it
==
tensor_map
.
end
())
if
(
it
==
tensor_map
.
end
())
{
{
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
string
name
=
op
->
output
(
i
).
get_tensor
(
).
get_name
();
host_tensor
=
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
name
);
host_tensor
=
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tensor
,
host_tensor
});
tensor_map
.
insert
({
tensor
,
host_tensor
});
}
}
...
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
}
}
if
(
m_nan_check_enabled
)
if
(
m_nan_check_enabled
)
{
{
perform_nan_check
(
op_outputs
,
op
);
perform_nan_check
(
op_outputs
,
op
.
get
()
);
}
}
}
}
...
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
...
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
void
runtime
::
gcpu
::
GCPUExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
void
runtime
::
gcpu
::
GCPUExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
NodeWrapper
&
op
,
const
vector
<
shared_ptr
<
HostTensor
>>&
out
puts
,
const
vector
<
shared_ptr
<
HostTensor
>>&
out
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
puts
)
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
{
vector
<
void
*>
out
;
vector
<
const
void
*>
in
;
for
(
auto
t
:
outputs
)
{
out
.
push_back
(
t
->
get_data_ptr
());
}
for
(
auto
t
:
inputs
)
{
in
.
push_back
(
t
->
get_data_ptr
());
}
stringstream
ss
;
stringstream
ss
;
switch
(
type
.
get_type_enum
())
switch
(
type
.
get_type_enum
())
{
{
...
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
...
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
bf16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
().
get_name
();
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
()
->
get_name
();
throw
ngraph_error
(
ss
.
str
());
throw
ngraph_error
(
ss
.
str
());
}
}
}
}
...
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
...
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
vector
<
runtime
::
PerformanceCounter
>
runtime
::
gcpu
::
GCPUExecutable
::
get_performance_data
()
const
vector
<
runtime
::
PerformanceCounter
>
runtime
::
gcpu
::
GCPUExecutable
::
get_performance_data
()
const
{
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
vector
<
runtime
::
PerformanceCounter
>
rc
;
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
m_timer_map
)
for
(
const
pair
<
shared_ptr
<
const
Node
>
,
stopwatch
>
p
:
m_timer_map
)
{
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
rc
.
emplace_back
(
p
.
first
,
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
}
}
return
rc
;
return
rc
;
}
}
...
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
...
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
arg_number
++
;
arg_number
++
;
}
}
}
}
void
runtime
::
gcpu
::
GCPUExecutable
::
save
(
ostream
&
out
)
{
cpio
::
Writer
writer
(
out
);
string
si
=
"INTERPRETER Save File 1.0"
;
writer
.
write
(
"save_info"
,
si
.
data
(),
si
.
size
());
string
model
=
serialize
(
m_function
,
0
);
writer
.
write
(
"model"
,
model
.
data
(),
model
.
size
());
}
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
View file @
334a55fa
...
@@ -17,24 +17,31 @@
...
@@ -17,24 +17,31 @@
#pragma once
#pragma once
#include <initializer_list>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <memory>
#include <sstream>
#include <sstream>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/gather.hpp"
...
@@ -48,11 +55,14 @@
...
@@ -48,11 +55,14 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
...
@@ -64,7 +74,6 @@
...
@@ -64,7 +74,6 @@
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/add.hpp"
...
@@ -77,7 +86,9 @@
...
@@ -77,7 +86,9 @@
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/concat.hpp"
...
@@ -89,8 +100,10 @@
...
@@ -89,8 +100,10 @@
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather.hpp"
...
@@ -117,14 +130,17 @@
...
@@ -117,14 +130,17 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sign.hpp"
...
@@ -134,6 +150,7 @@
...
@@ -134,6 +150,7 @@
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/reference/topk.hpp"
...
@@ -154,6 +171,8 @@ namespace ngraph
...
@@ -154,6 +171,8 @@ namespace ngraph
class
ngraph
::
runtime
::
gcpu
::
GCPUExecutable
:
public
Executable
class
ngraph
::
runtime
::
gcpu
::
GCPUExecutable
:
public
Executable
{
{
friend
class
GCPUBackend
;
public
:
public
:
GCPUExecutable
(
const
std
::
shared_ptr
<
Function
>&
function
,
GCPUExecutable
(
const
std
::
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
=
false
);
bool
enable_performance_collection
=
false
);
...
@@ -161,20 +180,25 @@ public:
...
@@ -161,20 +180,25 @@ public:
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
virtual
void
save
(
std
::
ostream
&
output_stream
)
override
;
void
set_nan_check
(
bool
enable
);
void
set_nan_check
(
bool
enable
);
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
override
;
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
override
;
private
:
private
:
GCPUExecutable
(
const
std
::
string
&
model_string
);
int
get_alignment
()
const
{
return
64
;
}
bool
m_is_compiled
=
false
;
bool
m_is_compiled
=
false
;
bool
m_nan_check_enabled
=
false
;
bool
m_nan_check_enabled
=
false
;
bool
m_performance_counters_enabled
=
false
;
bool
m_performance_counters_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
std
::
shared_ptr
<
Function
>
m_function
;
std
::
unordered_map
<
std
::
shared_ptr
<
const
Node
>
,
stopwatch
>
m_timer_map
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
int
get_alignment
()
const
{
return
64
;
}
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
const
Node
*
op
=
nullptr
);
const
Node
*
op
=
nullptr
);
...
@@ -185,11 +209,10 @@ private:
...
@@ -185,11 +209,10 @@ private:
template
<
typename
T
>
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
void
*
>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
out
,
const
std
::
vector
<
const
void
*
>&
args
)
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
args
)
{
{
const
Node
&
node
=
node_wrapper
.
get_node
();
const
Node
&
node
=
*
node_wrapper
.
get_node
();
std
::
string
node_op
=
node
.
description
();
// We want to check that every OP_TYPEID enumeration is included in the list.
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// These GCC flags enable compile-time checking so that if an enumeration
...
@@ -206,30 +229,30 @@ private:
...
@@ -206,30 +229,30 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
abs
<
T
>
(
reference
::
abs
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Acos
:
case
OP_TYPEID
:
:
Acos
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
acos
<
T
>
(
reference
::
acos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Add
:
case
OP_TYPEID
:
:
Add
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
add
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
add
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
All
:
case
OP_TYPEID
:
:
All
:
{
{
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
reference
::
all
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
all
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
all
->
get_reduction_axes
());
all
->
get_reduction_axes
());
...
@@ -237,26 +260,29 @@ private:
...
@@ -237,26 +260,29 @@ private:
}
}
case
OP_TYPEID
:
:
AllReduce
:
case
OP_TYPEID
:
:
AllReduce
:
{
{
reference
::
allreduce
<
T
>
(
static_cast
<
T
*>
(
const_cast
<
void
*>
(
args
[
0
])),
const
ngraph
::
op
::
AllReduce
*
allreduce
=
static_cast
<
T
*>
(
out
[
0
]),
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
node
.
get_input_element_type
(
0
),
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
break
;
}
}
case
OP_TYPEID
:
:
And
:
case
OP_TYPEID
:
:
And
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_and
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
logical_and
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Any
:
case
OP_TYPEID
:
:
Any
:
{
{
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
reference
::
any
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
any
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
any
->
get_reduction_axes
());
any
->
get_reduction_axes
());
...
@@ -268,16 +294,16 @@ private:
...
@@ -268,16 +294,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
if
(
element_type
==
element
::
i64
)
{
{
reference
::
argmin
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmin
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
argmin
->
get_reduction_axis
());
}
}
else
if
(
element_type
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
{
reference
::
argmin
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmin
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
argmin
->
get_reduction_axis
());
...
@@ -294,16 +320,16 @@ private:
...
@@ -294,16 +320,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
if
(
element_type
==
element
::
i64
)
{
{
reference
::
argmax
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmax
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
argmax
->
get_reduction_axis
());
}
}
else
if
(
element_type
==
element
::
i32
)
else
if
(
element_type
==
element
::
i32
)
{
{
reference
::
argmax
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
argmax
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
argmax
->
get_reduction_axis
());
...
@@ -318,22 +344,22 @@ private:
...
@@ -318,22 +344,22 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
asin
<
T
>
(
reference
::
asin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Atan
:
case
OP_TYPEID
:
:
Atan
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
atan
<
T
>
(
reference
::
atan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
AvgPool
:
case
OP_TYPEID
:
:
AvgPool
:
{
{
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
avg_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
avg_pool
->
get_window_shape
(),
avg_pool
->
get_window_shape
(),
...
@@ -345,18 +371,30 @@ private:
...
@@ -345,18 +371,30 @@ private:
}
}
case
OP_TYPEID
:
:
GenerateMask
:
case
OP_TYPEID
:
:
GenerateMask
:
{
{
bool
use_seed
=
static_cast
<
bool
>
(
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
()[
0
]);
if
(
m_states
.
count
(
&
node
)
==
0
)
if
(
m_states
.
count
(
&
node
)
==
0
)
{
{
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
auto
seed
=
use_seed
?
gm
->
get_seed
()
:
0
;
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
()
,
gm
->
get_probability
()));
ngraph
::
RNGState
::
create_rng_state
(
seed
,
gm
->
get_probability
()));
}
}
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
]
)[
0
]);
bool
training
=
static_cast
<
bool
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
)[
0
]);
auto
state
=
m_states
.
at
(
&
node
).
get
();
auto
state
=
m_states
.
at
(
&
node
).
get
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
if
(
!
use_seed
)
{
reference
::
generate_mask
<
T
>
(
reference
::
generate_mask
<
T
>
(
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
state
,
training
);
}
else
{
uint64_t
seed
=
static_cast
<
uint64_t
>
(
args
[
3
]
->
get_data_ptr
<
const
T
>
()[
0
]);
double
prob
=
static_cast
<
double
>
(
args
[
4
]
->
get_data_ptr
<
const
T
>
()[
0
]);
reference
::
generate_mask_no_state
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
training
,
seed
,
prob
);
}
break
;
break
;
}
}
case
OP_TYPEID
:
:
GetOutputElement
:
case
OP_TYPEID
:
:
GetOutputElement
:
...
@@ -366,20 +404,31 @@ private:
...
@@ -366,20 +404,31 @@ private:
size_t
n
=
get_output_element
->
get_n
();
size_t
n
=
get_output_element
->
get_n
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
std
::
memcpy
(
static_cast
<
T
*>
(
out
[
0
]),
args
[
n
],
num_bytes
);
std
::
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
n
]
->
get_data_ptr
<
T
>
(),
num_bytes
);
break
;
}
case
OP_TYPEID
:
:
BatchMatMul
:
{
reference
::
batch_mat_mul
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
BatchNormTraining
:
case
OP_TYPEID
:
:
BatchNormTraining
:
{
{
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
2
]
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
...
@@ -388,12 +437,12 @@ private:
...
@@ -388,12 +437,12 @@ private:
const
ngraph
::
op
::
BatchNormInference
*
bn
=
const
ngraph
::
op
::
BatchNormInference
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
3
]
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
4
]
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
...
@@ -402,23 +451,23 @@ private:
...
@@ -402,23 +451,23 @@ private:
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
3
]
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
4
]
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
5
]
),
args
[
5
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
static_cast
<
T
*>
(
out
[
2
]
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
node
.
get_input_shape
(
2
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
AvgPoolBackprop
:
case
OP_TYPEID
:
:
AvgPoolBackprop
:
{
{
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
avg_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
apb
->
get_window_shape
(),
apb
->
get_window_shape
(),
...
@@ -434,8 +483,8 @@ private:
...
@@ -434,8 +483,8 @@ private:
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
gcpu
::
kernel
::
broadcast
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
kernel
::
broadcast
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shape
,
in_shape
,
out_shape
,
out_shape
,
broadcast_axes
);
broadcast_axes
);
...
@@ -443,23 +492,28 @@ private:
...
@@ -443,23 +492,28 @@ private:
}
}
case
OP_TYPEID
:
:
BroadcastDistributed
:
case
OP_TYPEID
:
:
BroadcastDistributed
:
{
{
int
rank_ID
=
get_distributed_interface
()
->
get_rank
();
const
ngraph
::
op
::
BroadcastDistributed
*
broadcast
=
if
(
rank_ID
==
0
)
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
&
node
);
int
rank_ID
;
rank_ID
=
get_distributed_interface
()
->
get_rank
();
int
root_id
=
broadcast
->
get_root_id
();
if
(
rank_ID
==
root_id
)
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
args
[
0
]
),
args
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_element_type
(
0
),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
);
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
,
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
root_id
);
sizeof
(
node
.
get_input_element_type
(
0
)
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
memcpy
(
out
[
0
]
,
args
[
0
]
,
memSize
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
()
,
memSize
);
}
}
else
else
{
{
reference
::
broadcastdistributed
<
T
>
(
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
out
[
0
]),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
}
}
break
;
break
;
}
}
...
@@ -468,7 +522,7 @@ private:
...
@@ -468,7 +522,7 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
ceiling
<
T
>
(
reference
::
ceiling
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Concat
:
case
OP_TYPEID
:
:
Concat
:
...
@@ -478,11 +532,11 @@ private:
...
@@ -478,11 +532,11 @@ private:
std
::
vector
<
Shape
>
in_shapes
;
std
::
vector
<
Shape
>
in_shapes
;
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
{
{
in_args
.
push_back
(
static_cast
<
const
T
*>
(
args
[
i
]
));
in_args
.
push_back
(
args
[
i
]
->
get_data_ptr
<
const
T
>
(
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
}
}
reference
::
concat
<
T
>
(
in_args
,
reference
::
concat
<
T
>
(
in_args
,
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shapes
,
in_shapes
,
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
concat
->
get_concatenation_axis
());
concat
->
get_concatenation_axis
());
...
@@ -492,7 +546,7 @@ private:
...
@@ -492,7 +546,7 @@ private:
{
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
...
@@ -505,52 +559,62 @@ private:
...
@@ -505,52 +559,62 @@ private:
switch
(
type
.
get_type_enum
())
switch
(
type
.
get_type_enum
())
{
{
case
element
:
:
Type_t
::
boolean
:
case
element
:
:
Type_t
::
boolean
:
reference
::
convert
<
T
>
(
reference
::
convert
_to_bool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
char
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
f32
:
case
element
:
:
Type_t
::
f32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
float
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
f64
:
case
element
:
:
Type_t
::
f64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
double
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
double
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i8
:
case
element
:
:
Type_t
::
i8
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int8_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i16
:
case
element
:
:
Type_t
::
i16
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int16_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int16_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i32
:
case
element
:
:
Type_t
::
i32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int32_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
i64
:
case
element
:
:
Type_t
::
i64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int64_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
int64_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u8
:
case
element
:
:
Type_t
::
u8
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint8_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u16
:
case
element
:
:
Type_t
::
u16
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint16_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint16_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u32
:
case
element
:
:
Type_t
::
u32
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint32_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint32_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
u64
:
case
element
:
:
Type_t
::
u64
:
reference
::
convert
<
T
>
(
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint64_t
*>
(
out
[
0
]),
element_count
);
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(),
element_count
);
break
;
break
;
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
throw
std
::
runtime_error
(
ss
.
str
());
}
}
...
@@ -559,9 +623,9 @@ private:
...
@@ -559,9 +623,9 @@ private:
case
OP_TYPEID
:
:
Convolution
:
case
OP_TYPEID
:
:
Convolution
:
{
{
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
...
@@ -569,38 +633,26 @@ private:
...
@@ -569,38 +633,26 @@ private:
c
->
get_window_dilation_strides
(),
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
c
->
get_data_dilation_strides
());
0
,
1
,
1
,
0
,
0
,
1
,
false
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ConvolutionBackpropFilters
:
case
OP_TYPEID
:
:
ConvolutionBackpropFilters
:
{
{
const
op
::
ConvolutionBackpropFilters
*
c
=
const
op
::
ConvolutionBackpropFilters
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
reference
::
convolution_backprop_filter
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]),
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
// input
static_cast
<
T
*>
(
out
[
0
]),
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
// delta_convolution_output
node
.
get_input_shape
(
0
),
out
[
0
]
->
get_data_ptr
<
T
>
(),
// delta_filter
node
.
get_input_shape
(
1
),
c
->
get_input_shape
(
0
),
// input_shape
node
.
get_output_shape
(
0
),
c
->
get_input_shape
(
1
),
// convolution_output_shape
c
->
get_window_movement_strides_backward
(),
c
->
get_filters_shape
(),
// filter_shape
c
->
get_window_dilation_strides_backward
(),
c
->
get_window_dilation_strides_forward
(),
c
->
get_padding_below_backward
(),
c
->
get_window_movement_strides_forward
(),
c
->
get_padding_above_backward
(),
c
->
get_padding_below_forward
(),
c
->
get_data_dilation_strides_backward
(),
c
->
compute_backward_in_pad_above
(),
1
,
c
->
get_data_dilation_strides_forward
());
0
,
0
,
1
,
1
,
0
,
false
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ConvolutionBackpropData
:
case
OP_TYPEID
:
:
ConvolutionBackpropData
:
...
@@ -608,38 +660,31 @@ private:
...
@@ -608,38 +660,31 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
// Note that args[1] and args[0] are switched here from the usual order.
const
op
::
ConvolutionBackpropData
*
c
=
const
op
::
ConvolutionBackpropData
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]),
reference
::
convolution_backprop_in
<
T
>
(
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
const
T
*>
(
args
[
0
]),
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
static_cast
<
T
*>
(
out
[
0
]),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_shape
(
1
),
c
->
get_input_shape
(
1
),
node
.
get_input_shape
(
0
),
c
->
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
c
->
get_data_batch_shape
(),
c
->
get_window_movement_strides_backward
(),
c
->
get_data_dilation_strides_forward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_window_dilation_strides_forward
(),
c
->
get_padding_below_backward
(),
c
->
compute_backward_delta_out_pad_below
(),
c
->
get_padding_above_backward
(),
c
->
compute_backward_delta_out_pad_above
(),
c
->
get_data_dilation_strides_backward
(),
c
->
get_window_movement_strides_forward
());
0
,
1
,
0
,
1
,
0
,
1
,
true
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Cos
:
case
OP_TYPEID
:
:
Cos
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cos
<
T
>
(
reference
::
cos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Cosh
:
case
OP_TYPEID
:
:
Cosh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cosh
<
T
>
(
reference
::
cosh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Dequantize
:
case
OP_TYPEID
:
:
Dequantize
:
...
@@ -649,20 +694,20 @@ private:
...
@@ -649,20 +694,20 @@ private:
if
(
type
==
element
::
f32
)
if
(
type
==
element
::
f32
)
{
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
float
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
float
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
float
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
dequantize
->
get_axes
());
}
}
else
if
(
type
==
element
::
f64
)
else
if
(
type
==
element
::
f64
)
{
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
double
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
double
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
double
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
double
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
dequantize
->
get_axes
());
...
@@ -680,9 +725,9 @@ private:
...
@@ -680,9 +725,9 @@ private:
{
{
const
op
::
Divide
*
divop
=
static_cast
<
const
op
::
Divide
*>
(
&
node
);
const
op
::
Divide
*
divop
=
static_cast
<
const
op
::
Divide
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
divide
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
divide
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
divop
->
is_pythondiv
());
divop
->
is_pythondiv
());
break
;
break
;
...
@@ -691,15 +736,25 @@ private:
...
@@ -691,15 +736,25 @@ private:
{
{
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
gcpu
::
kernel
::
dot
(
static_cast
<
const
T
*>
(
args
[
0
]
),
kernel
::
dot
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
dot
->
get_reduction_axes_count
());
dot
->
get_reduction_axes_count
());
break
;
break
;
}
}
case
OP_TYPEID
:
:
DynReshape
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
DynSlice
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
EmbeddingLookup
:
case
OP_TYPEID
:
:
EmbeddingLookup
:
{
{
const
op
::
EmbeddingLookup
*
embed
=
static_cast
<
const
op
::
EmbeddingLookup
*>
(
&
node
);
const
op
::
EmbeddingLookup
*
embed
=
static_cast
<
const
op
::
EmbeddingLookup
*>
(
&
node
);
...
@@ -708,33 +763,33 @@ private:
...
@@ -708,33 +763,33 @@ private:
if
(
type
==
element
::
f32
)
if
(
type
==
element
::
f32
)
{
{
reference
::
embedding
<
T
,
float
>
(
static_cast
<
const
float
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
float
>
(
args
[
0
]
->
get_data_ptr
<
const
float
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
f64
)
else
if
(
type
==
element
::
f64
)
{
{
reference
::
embedding
<
T
,
double
>
(
static_cast
<
const
double
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
double
>
(
args
[
0
]
->
get_data_ptr
<
const
double
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
i32
)
else
if
(
type
==
element
::
i32
)
{
{
reference
::
embedding
<
T
,
int
>
(
static_cast
<
const
int
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
int
32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
else
if
(
type
==
element
::
i64
)
else
if
(
type
==
element
::
i64
)
{
{
reference
::
embedding
<
T
,
int64_t
>
(
static_cast
<
const
int64_t
*>
(
args
[
0
]
),
reference
::
embedding
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int64_t
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
element_count
,
embed
->
get_shape
());
embed
->
get_shape
());
}
}
...
@@ -748,24 +803,56 @@ private:
...
@@ -748,24 +803,56 @@ private:
case
OP_TYPEID
:
:
Equal
:
case
OP_TYPEID
:
:
Equal
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Erf
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
erf
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Exp
:
case
OP_TYPEID
:
:
Exp
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
exp
<
T
>
(
reference
::
exp
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
}
#ifdef INTERPRETER_USE_HYBRID
case
OP_TYPEID
:
:
FunctionCall
:
{
auto
f
=
static_cast
<
const
runtime
::
hybrid
::
op
::
FunctionCall
*>
(
&
node
);
auto
backend
=
f
->
get_backend
();
auto
executable
=
f
->
get_executable
();
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
outputs
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
inputs
;
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
out
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
outputs
.
push_back
(
backend_tensor
);
}
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
args
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
inputs
.
push_back
(
backend_tensor
);
}
executable
->
call
(
outputs
,
inputs
);
break
;
break
;
}
}
#endif
case
OP_TYPEID
:
:
Floor
:
case
OP_TYPEID
:
:
Floor
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
floor
<
T
>
(
reference
::
floor
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Gather
:
case
OP_TYPEID
:
:
Gather
:
...
@@ -826,36 +913,36 @@ private:
...
@@ -826,36 +913,36 @@ private:
case
OP_TYPEID
:
:
Greater
:
case
OP_TYPEID
:
:
Greater
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
greater
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
GreaterEq
:
case
OP_TYPEID
:
:
GreaterEq
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
greater_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Less
:
case
OP_TYPEID
:
:
Less
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
less
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
LessEq
:
case
OP_TYPEID
:
:
LessEq
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
less_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -863,14 +950,14 @@ private:
...
@@ -863,14 +950,14 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
log
<
T
>
(
reference
::
log
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
LRN
:
case
OP_TYPEID
:
:
LRN
:
{
{
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
reference
::
lrn
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
lrn
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
lrn
->
get_alpha
(),
lrn
->
get_alpha
(),
lrn
->
get_beta
(),
lrn
->
get_beta
(),
...
@@ -881,8 +968,8 @@ private:
...
@@ -881,8 +968,8 @@ private:
case
OP_TYPEID
:
:
Max
:
case
OP_TYPEID
:
:
Max
:
{
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max
->
get_reduction_axes
());
max
->
get_reduction_axes
());
...
@@ -891,9 +978,9 @@ private:
...
@@ -891,9 +978,9 @@ private:
case
OP_TYPEID
:
:
Maximum
:
case
OP_TYPEID
:
:
Maximum
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
maximum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
maximum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -901,8 +988,8 @@ private:
...
@@ -901,8 +988,8 @@ private:
{
{
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool
->
get_window_shape
(),
max_pool
->
get_window_shape
(),
...
@@ -916,9 +1003,9 @@ private:
...
@@ -916,9 +1003,9 @@ private:
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
max_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool_backprop
->
get_window_shape
(),
max_pool_backprop
->
get_window_shape
(),
...
@@ -930,8 +1017,8 @@ private:
...
@@ -930,8 +1017,8 @@ private:
case
OP_TYPEID
:
:
Min
:
case
OP_TYPEID
:
:
Min
:
{
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
min
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
min
->
get_reduction_axes
());
min
->
get_reduction_axes
());
...
@@ -940,18 +1027,18 @@ private:
...
@@ -940,18 +1027,18 @@ private:
case
OP_TYPEID
:
:
Minimum
:
case
OP_TYPEID
:
:
Minimum
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
minimum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
minimum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Multiply
:
case
OP_TYPEID
:
:
Multiply
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
multiply
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
multiply
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -959,30 +1046,30 @@ private:
...
@@ -959,30 +1046,30 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
negate
<
T
>
(
reference
::
negate
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Not
:
case
OP_TYPEID
:
:
Not
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_not
(
reference
::
logical_not
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
NotEqual
:
case
OP_TYPEID
:
:
NotEqual
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
not_equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
not_equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
char
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
OneHot
:
case
OP_TYPEID
:
:
OneHot
:
{
{
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
one_hot
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
oh
->
get_one_hot_axis
());
oh
->
get_one_hot_axis
());
...
@@ -991,46 +1078,46 @@ private:
...
@@ -991,46 +1078,46 @@ private:
case
OP_TYPEID
:
:
Or
:
case
OP_TYPEID
:
:
Or
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_or
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
logical_or
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Parameter
:
break
;
case
OP_TYPEID
:
:
Parameter
:
break
;
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Pad
:
case
OP_TYPEID
:
:
Pad
:
{
{
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
pad
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_inputs
().
a
t
(
0
).
get_shape
(),
node
.
inpu
t
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
node
.
output
(
0
).
get_shape
(
),
pad
->
get_padding_below
(),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_padding_above
(),
pad
->
get_pad
ding_interior
());
pad
->
get_pad
_mode
());
break
;
break
;
}
}
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Power
:
case
OP_TYPEID
:
:
Power
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
power
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
power
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Product
:
case
OP_TYPEID
:
:
Product
:
{
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
product
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
product
->
get_reduction_axes
());
product
->
get_reduction_axes
());
...
@@ -1043,10 +1130,10 @@ private:
...
@@ -1043,10 +1130,10 @@ private:
if
(
type
==
element
::
u8
)
if
(
type
==
element
::
u8
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
uint8_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
uint8_t
>
(
),
static_cast
<
uint8_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -1054,10 +1141,10 @@ private:
...
@@ -1054,10 +1141,10 @@ private:
}
}
else
if
(
type
==
element
::
i8
)
else
if
(
type
==
element
::
i8
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
int8_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
int8_t
>
(
),
static_cast
<
int8_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -1065,10 +1152,10 @@ private:
...
@@ -1065,10 +1152,10 @@ private:
}
}
else
if
(
type
==
element
::
i32
)
else
if
(
type
==
element
::
i32
)
{
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
int32_t
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
quantize
->
get_axes
(),
...
@@ -1083,40 +1170,168 @@ private:
...
@@ -1083,40 +1170,168 @@ private:
break
;
break
;
}
}
case
OP_TYPEID
:
:
QuantizedConvolution
:
{
const
op
::
QuantizedConvolution
*
qc
=
static_cast
<
const
op
::
QuantizedConvolution
*>
(
&
node
);
auto
input_element_type
=
qc
->
get_input_element_type
(
0
);
auto
filter_element_type
=
qc
->
get_input_element_type
(
1
);
auto
output_element_type
=
qc
->
get_output_element_type
(
0
);
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i8
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
u8
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
uint8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
uint8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
break
;
}
case
OP_TYPEID
:
:
QuantizedAvgPool
:
case
OP_TYPEID
:
:
QuantizedAvgPool
:
case
OP_TYPEID
:
:
QuantizedConvolutionBias
:
case
OP_TYPEID
:
:
QuantizedConvolutionBias
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasSignedAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasSignedAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionRelu
:
case
OP_TYPEID
:
:
QuantizedConvolutionRelu
:
case
OP_TYPEID
:
:
QuantizedConvolution
:
case
OP_TYPEID
:
:
QuantizedMaxPool
:
case
OP_TYPEID
:
:
QuantizedMaxPool
:
case
OP_TYPEID
:
:
QuantizedDotBias
:
case
OP_TYPEID
:
:
QuantizedDotBias
:
case
OP_TYPEID
:
:
QuantizedDot
:
case
OP_TYPEID
:
:
QuantizedDot
:
{
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'."
);
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"' in Interpreter back end."
);
}
case
OP_TYPEID
:
:
Recv
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
Range
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
}
case
OP_TYPEID
:
:
Relu
:
case
OP_TYPEID
:
:
Relu
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu
<
T
>
(
reference
::
relu
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReluBackprop
:
case
OP_TYPEID
:
:
ReluBackprop
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
relu_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
ReplaceSlice
:
case
OP_TYPEID
:
:
ReplaceSlice
:
{
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
replace_slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
1
),
slice
->
get_lower_bounds
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_upper_bounds
(),
...
@@ -1127,8 +1342,8 @@ private:
...
@@ -1127,8 +1342,8 @@ private:
case
OP_TYPEID
:
:
Reshape
:
case
OP_TYPEID
:
:
Reshape
:
{
{
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
gcpu
::
kernel
::
reshape
(
static_cast
<
const
T
*>
(
args
[
0
]
),
kernel
::
reshape
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
reshape
->
get_input_order
(),
node
.
get_output_shape
(
0
));
node
.
get_output_shape
(
0
));
...
@@ -1137,16 +1352,16 @@ private:
...
@@ -1137,16 +1352,16 @@ private:
case
OP_TYPEID
:
:
Result
:
case
OP_TYPEID
:
:
Result
:
{
{
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
reference
::
result
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
result
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
shape_size
(
res
->
get_shape
()));
shape_size
(
res
->
get_shape
()));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Reverse
:
case
OP_TYPEID
:
:
Reverse
:
{
{
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
reverse
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
reverse
->
get_reversed_axes
());
reverse
->
get_reversed_axes
());
...
@@ -1158,12 +1373,12 @@ private:
...
@@ -1158,12 +1373,12 @@ private:
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
{
{
reference
::
reverse_sequence
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
reverse_sequence
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
reverse
->
get_batch_axis
(),
reverse
->
get_batch_axis
(),
reverse
->
get_sequence_axis
(),
reverse
->
get_sequence_axis
(),
static_cast
<
const
int32_t
*>
(
args
[
1
]
));
args
[
1
]
->
get_data_ptr
<
const
int32_t
>
(
));
}
}
else
else
{
{
...
@@ -1234,31 +1449,46 @@ private:
...
@@ -1234,31 +1449,46 @@ private:
case
OP_TYPEID
:
:
Select
:
case
OP_TYPEID
:
:
Select
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
select
<
T
>
(
static_cast
<
const
char
*>
(
args
[
0
]
),
reference
::
select
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
2
]
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Send
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Send
*>
(
&
node
);
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
dest_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
ShapeOf
:
case
OP_TYPEID
:
:
ShapeOf
:
{
{
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
static_cast
<
uint64_t
*>
(
out
[
0
]
));
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(
));
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sigmoid
:
case
OP_TYPEID
:
:
Sigmoid
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid
<
T
>
(
reference
::
sigmoid
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
SigmoidBackprop
:
case
OP_TYPEID
:
:
SigmoidBackprop
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
sigmoid_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
...
@@ -1266,28 +1496,28 @@ private:
...
@@ -1266,28 +1496,28 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sign
<
T
>
(
reference
::
sign
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sin
:
case
OP_TYPEID
:
:
Sin
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sin
<
T
>
(
reference
::
sin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sinh
:
case
OP_TYPEID
:
:
Sinh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sinh
<
T
>
(
reference
::
sinh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Slice
:
case
OP_TYPEID
:
:
Slice
:
{
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
slice
->
get_lower_bounds
(),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
slice
->
get_upper_bounds
(),
...
@@ -1298,8 +1528,8 @@ private:
...
@@ -1298,8 +1528,8 @@ private:
case
OP_TYPEID
:
:
Softmax
:
case
OP_TYPEID
:
:
Softmax
:
{
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
softmax
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
softmax
->
get_axes
());
softmax
->
get_axes
());
break
;
break
;
...
@@ -1308,7 +1538,7 @@ private:
...
@@ -1308,7 +1538,7 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sqrt
<
T
>
(
reference
::
sqrt
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
...
@@ -1316,17 +1546,17 @@ private:
...
@@ -1316,17 +1546,17 @@ private:
case
OP_TYPEID
:
:
Subtract
:
case
OP_TYPEID
:
:
Subtract
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
subtract
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
subtract
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
const
T
*>
(
args
[
1
]
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Sum
:
case
OP_TYPEID
:
:
Sum
:
{
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
sum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
sum
->
get_reduction_axes
());
sum
->
get_reduction_axes
());
...
@@ -1336,14 +1566,14 @@ private:
...
@@ -1336,14 +1566,14 @@ private:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tan
<
T
>
(
reference
::
tan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
Tanh
:
case
OP_TYPEID
:
:
Tanh
:
{
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tanh
<
T
>
(
reference
::
tanh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
break
;
}
}
case
OP_TYPEID
:
:
TopK
:
case
OP_TYPEID
:
:
TopK
:
...
@@ -1351,9 +1581,9 @@ private:
...
@@ -1351,9 +1581,9 @@ private:
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
{
{
reference
::
topk
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
topk
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_top_k_axis
(),
...
@@ -1362,9 +1592,9 @@ private:
...
@@ -1362,9 +1592,9 @@ private:
}
}
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
{
{
reference
::
topk
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
reference
::
topk
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
static_cast
<
T
*>
(
out
[
1
]
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
topk
->
get_top_k_axis
(),
...
@@ -1377,7 +1607,12 @@ private:
...
@@ -1377,7 +1607,12 @@ private:
}
}
break
;
break
;
}
}
default
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
case
OP_TYPEID
:
:
DynBroadcast
:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
Tile
:
case
OP_TYPEID
:
:
DynReplaceSlice
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#pragma GCC diagnostic pop
#endif
#endif
...
...
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
View file @
334a55fa
...
@@ -140,6 +140,91 @@ namespace ngraph
...
@@ -140,6 +140,91 @@ namespace ngraph
}
}
}
}
template
<
typename
T
>
void
broadcast_5d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
5
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
5
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
3
]
*
out_shape
[
4
]
+
index
[
4
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
template
<
typename
T
>
void
broadcast_6d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
6
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
6
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
for
(
index
[
5
]
=
0
;
index
[
5
]
<
out_shape
[
5
];
++
index
[
5
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
4
]
*
out_shape
[
5
]
+
index
[
5
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
void
broadcast
(
const
T
*
in
,
void
broadcast
(
const
T
*
in
,
T
*
out
,
T
*
out
,
...
@@ -167,6 +252,16 @@ namespace ngraph
...
@@ -167,6 +252,16 @@ namespace ngraph
case
4
:
case
4
:
broadcast_4d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
broadcast_4d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
break
;
case
5
:
broadcast_5d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
case
6
:
broadcast_6d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
default
:
runtime
::
reference
::
broadcast
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
}
}
}
}
else
else
...
...
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
View file @
334a55fa
...
@@ -244,10 +244,7 @@ namespace ngraph
...
@@ -244,10 +244,7 @@ namespace ngraph
case
4
:
reshape_in4
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
4
:
reshape_in4
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
5
:
reshape_in5
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
5
:
reshape_in5
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
6
:
reshape_in6
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
6
:
reshape_in6
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
default
:
default
:
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
NGRAPH_INFO
<<
"reference::reshape"
;
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
}
}
}
}
}
}
...
...
src/ngraph/runtime/generic_cpu/kernel/result.hpp
deleted
100644 → 0
View file @
5cfe1075
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
gcpu
{
namespace
kernel
{
template
<
typename
T
>
void
result
(
const
T
*
arg
,
T
*
out
,
size_t
count
)
{
memcpy
(
out
,
arg
,
sizeof
(
T
)
*
count
);
}
}
}
}
}
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
View file @
334a55fa
...
@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
...
@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public
:
public
:
NodeWrapper
(
const
std
::
shared_ptr
<
const
ngraph
::
Node
>&
node
);
NodeWrapper
(
const
std
::
shared_ptr
<
const
ngraph
::
Node
>&
node
);
const
Node
&
get_node
()
const
{
return
*
m_node
;
}
std
::
shared_ptr
<
const
Node
>
get_node
()
const
{
return
m_node
;
}
ngraph
::
runtime
::
gcpu
::
OP_TYPEID
get_typeid
()
const
{
return
m_typeid
;
}
ngraph
::
runtime
::
gcpu
::
OP_TYPEID
get_typeid
()
const
{
return
m_typeid
;
}
private
:
private
:
std
::
shared_ptr
<
const
ngraph
::
Node
>
m_node
;
std
::
shared_ptr
<
const
ngraph
::
Node
>
m_node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment