Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in / Register
Toggle navigation
N
ngraph
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Packages
Packages
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
submodule
ngraph
Commits
60252edd
Unverified
Commit
60252edd
authored
Jul 09, 2019
by
Scott Cyphers
Committed by
GitHub
Jul 09, 2019
Browse files
Options
Browse Files
Download
Plain Diff
Merge branch 'master' into ayzhuang/batch_norm_infer_relu_fusion
parents
341205cf
47342339
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
719 additions
and
528 deletions
+719
-528
external_mkldnn.cmake
cmake/external_mkldnn.cmake
+3
-3
mkldnn.patch
cmake/mkldnn.patch
+0
-13
conf.py
doc/sphinx/conf.py
+2
-4
ngversions.html
doc/sphinx/ngraph_theme/ngversions.html
+2
-2
release-notes.rst
doc/sphinx/source/project/release-notes.rst
+22
-39
test_requirements.txt
python/test_requirements.txt
+1
-0
CMakeLists.txt
src/ngraph/CMakeLists.txt
+0
-1
reshape.hpp
src/ngraph/op/util/reshape.hpp
+0
-81
allocator.cpp
src/ngraph/runtime/allocator.cpp
+3
-2
allocator.hpp
src/ngraph/runtime/allocator.hpp
+1
-1
cpu_backend.cpp
src/ngraph/runtime/cpu/cpu_backend.cpp
+1
-1
CMakeLists.txt
src/ngraph/runtime/generic_cpu/CMakeLists.txt
+4
-4
gcpu_backend.cpp
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
+2
-2
gcpu_executable.cpp
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
+47
-30
gcpu_executable.hpp
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
+534
-299
broadcast.hpp
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
+95
-0
reshape.hpp
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
+1
-4
result.hpp
src/ngraph/runtime/generic_cpu/kernel/result.hpp
+0
-41
node_wrapper.hpp
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
+1
-1
No files found.
cmake/external_mkldnn.cmake
View file @
60252edd
...
...
@@ -18,10 +18,10 @@ include(ExternalProject)
# Includes blas 3.8.0 in mkldnn
set
(
NGRAPH_MKLDNN_SHORT_VERSION 0
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 0.
19
.0.0
)
set
(
NGRAPH_MKLDNN_VERSION
"v0.
19
"
)
set
(
NGRAPH_MKLDNN_FULL_VERSION 0.
20
.0.0
)
set
(
NGRAPH_MKLDNN_VERSION
"v0.
20
"
)
set
(
NGRAPH_MKLDNN_SUB_VERSION
"2019.0.5.20190502"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"
027de76
"
)
set
(
NGRAPH_MKLDNN_GIT_TAG
"
v0.20
"
)
#------------------------------------------------------------------------------
# Fetch and install MKL-DNN
...
...
cmake/mkldnn.patch
View file @
60252edd
...
...
@@ -28,16 +28,3 @@ index f10feb20..05f47961 100644
set_property(TARGET ${LIB_NAME} PROPERTY PUBLIC_HEADER ${HEADERS})
target_include_directories(${LIB_NAME} PUBLIC
diff --git a/src/cpu/jit_avx512_common_conv_kernel.cpp b/src/cpu/jit_avx512_common_conv_kernel.cpp
index 1bb98fa43..b8b54401f 100644
--- a/src/cpu/jit_avx512_common_conv_kernel.cpp
+++ b/src/cpu/jit_avx512_common_conv_kernel.cpp
@@ -3055,7 +3055,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_3d() {
void jit_avx512_common_conv_bwd_weights_kernel_f32
::compute_oh_loop_common()
{
- assert(jcp.harness == harness_mb_reduction);
+ assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));
int b_pad = jcp.b_pad;
int t_pad = jcp.t_pad;
bool is_dilated = jcp.dilate_h != 0;
doc/sphinx/conf.py
View file @
60252edd
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
...
...
@@ -73,11 +71,11 @@ author = 'Intel Corporation'
# built documents.
#
# The short X.Y version.
version
=
'0.2
2
'
version
=
'0.2
3
'
# The Documentation full version, including alpha/beta/rc tags. Some features
# available in the latest code will not necessarily be documented first
release
=
'0.2
2
.0'
release
=
'0.2
3
.0'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
...
...
doc/sphinx/ngraph_theme/ngversions.html
View file @
60252edd
...
...
@@ -9,11 +9,11 @@
<dt>
{{ _('Recent Versions') }}
</dt>
<dd>
<!-- Until our https://docs.ngraph.ai/ publishing is set up, we link to GitHub -->
<ul>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.22.0"
>
0.22
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.23.0"
>
0.23.0
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.22.0"
>
0.22.0
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.21.0"
>
0.21.0
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.20.0"
>
0.20.0
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.19.0"
>
0.19.0
</a></li>
<li><a
href=
"https://github.com/NervanaSystems/ngraph/releases/tag/v0.18.1"
>
0.18.1
</a></li>
</ul></dd>
</dl>
<dl>
...
...
doc/sphinx/source/project/release-notes.rst
View file @
60252edd
...
...
@@ -6,28 +6,30 @@ Release Notes
nGraph is provided as source code, APIs, build scripts, and some binary formats
for various Compiler stack configurations and use cases.
For downloads formatted as ``.zip`` and ``tar.gz``, see
https://github.com/NervanaSystems/ngraph/releases.
This page includes additional documentation updates.
We are pleased to announce the release of version |version|-doc.
==============================
Core updates for |version|
~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ PlaidML support
+ More ONNX ops
+
Optimization
s
+
Don't reseed RNG on each use
+
Elementwise divide defaults to Python semantic
s
+
GenerateMask seed optional
0.22-doc
--------
+ Initial doc and API for IntelGPU backend.
+ DynamicBackend API.
+ Note deprecation of support of MXNet's ``ngraph-mxnet`` PyPI.
+ Noted changes on graph inspection options resultant from PR 3016.
+ Added better tips and details to doc-contributor-README.
Latest doc updates
~~~~~~~~~~~~~~~~~~
+ Document new debug tool
+ Note deprecation of MXNet's ``ngraph-mxnet`` PyPI
+ Note default change to `svg` files for graphs and visualization
+ Add more prominent tips for contributors who find the doc-contributor-README
.. important:: Pre-releases (``-rc-0.*``) have newer features, and are less stable.
...
...
@@ -36,8 +38,15 @@ Core updates for |version|
Changelog on Previous Releases
==============================
For downloads formatted as ``.zip`` and ``tar.gz``, see
https://github.com/NervanaSystems/ngraph/releases.
0.22
----
+ More ONNX ops
+ Optimizations
+ Don't reseed RNG on each use
+ Initial doc and API for IntelGPU backend
+ DynamicBackend API
0.21
----
...
...
@@ -51,12 +60,6 @@ https://github.com/NervanaSystems/ngraph/releases.
+ offset arg for tensor creation is deprecated
+ static linking support
+ Initial test of 0.21-doc
0.21-doc
--------
Summary of documentation-related changes:
+ Updated :doc:`doc-contributor-README` for new community-based contributions.
+ Added instructions on how to test or display the installed nGraph version.
+ Added instructions on building nGraph bridge (ngraph-bridge).
...
...
@@ -82,8 +85,6 @@ Summary of documentation-related changes:
0.19
----
**Download** `0.19.0-rc.2`_
+ More dynamic shape preparation
+ Distributed interface factored out
+ fp16 and bfloat16 types
...
...
@@ -103,9 +104,6 @@ Summary of documentation-related changes:
0.18
----
**Download** `0.18.1`_
+ Python formatting issue
+ mkl-dnn work-around
+ Event tracing improvements
...
...
@@ -118,8 +116,6 @@ Summary of documentation-related changes:
0.17
----
**Download** `0.17.0-rc.1`_
+ Allow negative padding in more places
+ Add code generation for some quantized ops
+ Preliminary dynamic shape support
...
...
@@ -131,11 +127,6 @@ Summary of documentation-related changes:
0.16
----
* **Download**: `0.16.0-rc.3`_
* **Download** `0.16.0-rc.2`_
* **Download** `0.16.0-rc.1`_
+ NodeInput and NodeOutput classes prepare for simplifications of Node
+ Test improvements
+ Additional quantization ops
...
...
@@ -143,11 +134,3 @@ Summary of documentation-related changes:
+ Fix memory leak
+ Concat optimization
+ Doc updates
.. _0.20.0-rc.0: https://github.com/NervanaSystems/ngraph/releases/tag/v0.20.0-rc.0_
.. _0.19.0-rc.2: https://github.com/NervanaSystems/ngraph/releases/tag/v0.19.0-rc.2_
.. _0.18.1: https://github.com/NervanaSystems/ngraph/releases/tag/v0.18.1_
.. _0.17.0-rc.1: `https://github.com/NervanaSystems/ngraph/releases/tag/v0.17.0-rc.1
.. _0.16.0-rc.3: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.3
.. _0.16.0-rc.2: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.2
.. _0.16.0-rc.1: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.1
python/test_requirements.txt
View file @
60252edd
pytest
tox
pydocstyle==3.0.0
flake8
flake8-commas
flake8-comprehensions
...
...
src/ngraph/CMakeLists.txt
View file @
60252edd
...
...
@@ -370,7 +370,6 @@ set (SRC
op/util/index_reduction.hpp
op/util/logical_reduction.cpp
op/util/logical_reduction.hpp
op/util/reshape.hpp
op/util/rnn_cell_base.cpp
op/util/rnn_cell_base.hpp
op/util/unary_elementwise_arithmetic.cpp
...
...
src/ngraph/op/util/reshape.hpp
deleted
100644 → 0
View file @
341205cf
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
op
{
namespace
util
{
/// \brief Change shape of input tensor.
///
/// \param[in] node The node producing the tensor to be reshaped.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing a Reshape operation.
///
std
::
shared_ptr
<
ngraph
::
Node
>
reshape
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
const
Shape
&
shape
)
{
return
builder
::
reshape
(
node
,
shape
);
}
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
///
/// \return: New node with permuted axes.
std
::
shared_ptr
<
ngraph
::
Node
>
reorder_axes
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
std
::
vector
<
std
::
size_t
>
axes_order
)
{
return
builder
::
reorder_axes
(
node
,
axes_order
);
}
/// \brief Return transposed tensor (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
///
/// \return: New node with reversed dimensions.
std
::
shared_ptr
<
ngraph
::
Node
>
transpose
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
)
{
return
builder
::
transpose
(
node
);
}
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node will be a 2D matrix representing the flattened input node.
std
::
shared_ptr
<
ngraph
::
Node
>
flatten
(
const
std
::
shared_ptr
<
ngraph
::
Node
>&
node
,
int
axis
)
{
return
builder
::
flatten
(
node
,
axis
);
}
}
// namespace util
}
// namespace op
}
// namespace ngraph
src/ngraph/runtime/allocator.cpp
View file @
60252edd
...
...
@@ -49,7 +49,8 @@ public:
}
};
std
::
unique_ptr
<
ngraph
::
runtime
::
Allocator
>
ngraph
::
runtime
::
create
_default_allocator
()
ngraph
::
runtime
::
Allocator
*
ngraph
::
runtime
::
get
_default_allocator
()
{
return
std
::
unique_ptr
<
DefaultAllocator
>
(
new
DefaultAllocator
());
static
std
::
unique_ptr
<
DefaultAllocator
>
allocator
(
new
DefaultAllocator
());
return
allocator
.
get
();
}
src/ngraph/runtime/allocator.hpp
View file @
60252edd
...
...
@@ -30,7 +30,7 @@ namespace ngraph
class
DefaultAllocator
;
/// \brief Create a default allocator that calls into system
/// allocation libraries
std
::
unique_ptr
<
Allocator
>
create
_default_allocator
();
ngraph
::
runtime
::
Allocator
*
get
_default_allocator
();
}
}
...
...
src/ngraph/runtime/cpu/cpu_backend.cpp
View file @
60252edd
...
...
@@ -185,7 +185,7 @@ runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
{
if
(
!
m_allocator
)
{
m_allocator
=
create
_default_allocator
();
return
runtime
::
get
_default_allocator
();
}
return
m_allocator
.
get
();
}
...
...
src/ngraph/runtime/generic_cpu/CMakeLists.txt
View file @
60252edd
...
...
@@ -15,10 +15,10 @@
# ******************************************************************************
if
(
NGRAPH_GENERIC_CPU_ENABLE
)
find_package
(
OpenMP
)
if
(
OPENMP_FOUND
)
add_compile_options
(
${
OpenMP_CXX_FLAGS
}
)
endif
()
#
find_package(OpenMP)
#
if (OPENMP_FOUND)
#
add_compile_options(${OpenMP_CXX_FLAGS})
#
endif()
add_library
(
gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp
)
if
(
NGRAPH_LIB_VERSIONING_ENABLE
)
set_target_properties
(
gcpu_backend PROPERTIES
...
...
src/ngraph/runtime/generic_cpu/gcpu_backend.cpp
View file @
60252edd
...
...
@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
)
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
);
}
shared_ptr
<
runtime
::
Tensor
>
runtime
::
gcpu
::
GCPUBackend
::
create_tensor
(
const
element
::
Type
&
type
,
const
Shape
&
shape
,
void
*
memory_pointer
)
{
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
,
this
);
return
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
memory_pointer
);
}
shared_ptr
<
runtime
::
Executable
>
...
...
src/ngraph/runtime/generic_cpu/gcpu_executable.cpp
View file @
60252edd
...
...
@@ -15,17 +15,22 @@
//*****************************************************************************
#include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/runtime/backend_manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
using
namespace
std
;
...
...
@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
enable_performance_collection
}
{
{
m_is_compiled
=
true
;
m_function
=
clone_function
(
*
function
);
pass
::
Manager
pass_manager
;
pass_manager
.
register_pass
<
pass
::
LikeReplacement
>
();
pass_manager
.
register_pass
<
pass
::
FusedOpDecomposition
>
();
pass_manager
.
register_pass
<
pass
::
ImplicitBroadcastElimination
>
();
pass_manager
.
register_pass
<
pass
::
AssignLayout
<
DenseTensorLayout
>>
();
pass_manager
.
register_pass
<
pass
::
Liveness
>
();
pass_manager
.
run_passes
(
function
);
pass_manager
.
run_passes
(
m_
function
);
for
(
const
shared_ptr
<
Node
>&
node
:
function
->
get_ordered_ops
())
for
(
const
shared_ptr
<
Node
>&
node
:
m_
function
->
get_ordered_ops
())
{
m_wrapped_nodes
.
emplace_back
(
node
);
}
set_parameters_and_results
(
*
m_function
);
}
runtime
::
gcpu
::
GCPUExecutable
::
GCPUExecutable
(
const
std
::
string
&
model_string
)
:
m_is_compiled
{
true
}
,
m_performance_counters_enabled
{
false
}
{
m_function
=
deserialize
(
model_string
);
for
(
const
shared_ptr
<
Node
>&
node
:
m_function
->
get_ordered_ops
())
{
m_wrapped_nodes
.
emplace_back
(
node
);
}
set_parameters_and_results
(
*
function
);
set_parameters_and_results
(
*
m_
function
);
}
bool
runtime
::
gcpu
::
GCPUExecutable
::
call
(
const
vector
<
shared_ptr
<
runtime
::
Tensor
>>&
outputs
,
...
...
@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
for
(
size_t
i
=
0
;
i
<
param
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
tensor
=
param
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
param
->
output
(
i
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_inputs
[
input_count
++
]});
}
}
...
...
@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{
throw
ngraph_error
(
"One of function's outputs isn't op::Result"
);
}
descriptor
::
Tensor
*
tensor
=
output
->
get_output_tensor_ptr
(
0
).
get
();
descriptor
::
Tensor
*
tensor
=
&
output
->
output
(
0
).
get_tensor
();
tensor_map
.
insert
({
tensor
,
func_outputs
[
output_count
]});
}
// for each ordered op in the graph
for
(
const
NodeWrapper
&
wrapped
:
m_wrapped_nodes
)
{
const
Node
*
op
=
&
wrapped
.
get_node
();
auto
op
=
wrapped
.
get_node
();
auto
type_id
=
wrapped
.
get_typeid
();
if
(
type_id
==
OP_TYPEID
::
Parameter
)
{
...
...
@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
// get op inputs from map
vector
<
shared_ptr
<
HostTensor
>>
op_inputs
;
for
(
const
descriptor
::
Input
&
input
:
op
->
get_
inputs
())
for
(
auto
input
:
op
->
inputs
())
{
descriptor
::
Tensor
*
tensor
=
input
.
get_output
().
get_tensor_ptr
().
get
();
descriptor
::
Tensor
*
tensor
=
&
input
.
get_tensor
();
op_inputs
.
push_back
(
tensor_map
.
at
(
tensor
));
}
...
...
@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
vector
<
shared_ptr
<
HostTensor
>>
op_outputs
;
for
(
size_t
i
=
0
;
i
<
op
->
get_output_size
();
++
i
)
{
descriptor
::
Tensor
*
tensor
=
op
->
get_output_tensor_ptr
(
i
).
get
();
descriptor
::
Tensor
*
tensor
=
&
op
->
output
(
i
).
get_tensor
();
shared_ptr
<
HostTensor
>
host_tensor
;
auto
it
=
tensor_map
.
find
(
tensor
);
if
(
it
==
tensor_map
.
end
())
{
const
Shape
&
shape
=
op
->
get_output_shape
(
i
);
const
element
::
Type
&
type
=
op
->
get_output_element_type
(
i
);
string
name
=
op
->
get_output_tensor
(
i
).
get_name
();
string
name
=
op
->
output
(
i
).
get_tensor
(
).
get_name
();
host_tensor
=
make_shared
<
runtime
::
HostTensor
>
(
type
,
shape
,
name
);
tensor_map
.
insert
({
tensor
,
host_tensor
});
}
...
...
@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
}
if
(
m_nan_check_enabled
)
{
perform_nan_check
(
op_outputs
,
op
);
perform_nan_check
(
op_outputs
,
op
.
get
()
);
}
}
...
...
@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
void
runtime
::
gcpu
::
GCPUExecutable
::
generate_calls
(
const
element
::
Type
&
type
,
const
NodeWrapper
&
op
,
const
vector
<
shared_ptr
<
HostTensor
>>&
out
puts
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
puts
)
const
vector
<
shared_ptr
<
HostTensor
>>&
out
,
const
vector
<
shared_ptr
<
HostTensor
>>&
in
)
{
vector
<
void
*>
out
;
vector
<
const
void
*>
in
;
for
(
auto
t
:
outputs
)
{
out
.
push_back
(
t
->
get_data_ptr
());
}
for
(
auto
t
:
inputs
)
{
in
.
push_back
(
t
->
get_data_ptr
());
}
stringstream
ss
;
switch
(
type
.
get_type_enum
())
{
...
...
@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
().
get_name
();
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op "
<<
op
.
get_node
()
->
get_name
();
throw
ngraph_error
(
ss
.
str
());
}
}
...
...
@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
vector
<
runtime
::
PerformanceCounter
>
runtime
::
gcpu
::
GCPUExecutable
::
get_performance_data
()
const
{
vector
<
runtime
::
PerformanceCounter
>
rc
;
for
(
const
pair
<
const
Node
*
,
stopwatch
>
p
:
m_timer_map
)
for
(
const
pair
<
shared_ptr
<
const
Node
>
,
stopwatch
>
p
:
m_timer_map
)
{
rc
.
emplace_back
(
p
.
first
->
get_name
().
c_str
(),
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
rc
.
emplace_back
(
p
.
first
,
p
.
second
.
get_total_microseconds
(),
p
.
second
.
get_call_count
());
}
return
rc
;
}
...
...
@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
arg_number
++
;
}
}
void
runtime
::
gcpu
::
GCPUExecutable
::
save
(
ostream
&
out
)
{
cpio
::
Writer
writer
(
out
);
string
si
=
"INTERPRETER Save File 1.0"
;
writer
.
write
(
"save_info"
,
si
.
data
(),
si
.
size
());
string
model
=
serialize
(
m_function
,
0
);
writer
.
write
(
"model"
,
model
.
data
(),
model
.
size
());
}
src/ngraph/runtime/generic_cpu/gcpu_executable.hpp
View file @
60252edd
...
...
@@ -17,24 +17,31 @@
#pragma once
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp"
...
...
@@ -48,11 +55,14 @@
#include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp"
...
...
@@ -64,7 +74,6 @@
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp"
...
...
@@ -77,7 +86,9 @@
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
...
...
@@ -89,8 +100,10 @@
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
...
...
@@ -117,14 +130,17 @@
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
...
...
@@ -134,6 +150,7 @@
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
...
...
@@ -154,6 +171,8 @@ namespace ngraph
class
ngraph
::
runtime
::
gcpu
::
GCPUExecutable
:
public
Executable
{
friend
class
GCPUBackend
;
public
:
GCPUExecutable
(
const
std
::
shared_ptr
<
Function
>&
function
,
bool
enable_performance_collection
=
false
);
...
...
@@ -161,20 +180,25 @@ public:
bool
call
(
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
outputs
,
const
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>&
intputs
)
override
;
virtual
void
save
(
std
::
ostream
&
output_stream
)
override
;
void
set_nan_check
(
bool
enable
);
std
::
vector
<
PerformanceCounter
>
get_performance_data
()
const
override
;
private
:
GCPUExecutable
(
const
std
::
string
&
model_string
);
int
get_alignment
()
const
{
return
64
;
}
bool
m_is_compiled
=
false
;
bool
m_nan_check_enabled
=
false
;
bool
m_performance_counters_enabled
=
false
;
std
::
unordered_map
<
const
Node
*
,
stopwatch
>
m_timer_map
;
std
::
shared_ptr
<
Function
>
m_function
;
std
::
unordered_map
<
std
::
shared_ptr
<
const
Node
>
,
stopwatch
>
m_timer_map
;
std
::
vector
<
NodeWrapper
>
m_wrapped_nodes
;
std
::
unordered_map
<
const
Node
*
,
std
::
shared_ptr
<
RNGState
>>
m_states
;
std
::
set
<
std
::
string
>
m_unsupported_op_name_list
;
int
get_alignment
()
const
{
return
64
;
}
static
void
perform_nan_check
(
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>>&
,
const
Node
*
op
=
nullptr
);
...
...
@@ -185,11 +209,10 @@ private:
template
<
typename
T
>
void
op_engine
(
const
NodeWrapper
&
node_wrapper
,
const
std
::
vector
<
void
*
>&
out
,
const
std
::
vector
<
const
void
*
>&
args
)
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
out
,
const
std
::
vector
<
std
::
shared_ptr
<
HostTensor
>
>&
args
)
{
const
Node
&
node
=
node_wrapper
.
get_node
();
std
::
string
node_op
=
node
.
description
();
const
Node
&
node
=
*
node_wrapper
.
get_node
();
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
...
...
@@ -206,30 +229,30 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
abs
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Acos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
acos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Add
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
add
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
add
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
All
:
{
const
op
::
All
*
all
=
static_cast
<
const
op
::
All
*>
(
&
node
);
reference
::
all
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
all
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
all
->
get_reduction_axes
());
...
...
@@ -237,26 +260,29 @@ private:
}
case
OP_TYPEID
:
:
AllReduce
:
{
reference
::
allreduce
<
T
>
(
static_cast
<
T
*>
(
const_cast
<
void
*>
(
args
[
0
])),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_element_type
(
0
),
const
ngraph
::
op
::
AllReduce
*
allreduce
=
static_cast
<
const
ngraph
::
op
::
AllReduce
*>
(
&
node
);
reference
::
allreduce
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
allreduce
->
get_reduce_type
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
break
;
}
case
OP_TYPEID
:
:
And
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_and
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
logical_and
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Any
:
{
const
op
::
Any
*
any
=
static_cast
<
const
op
::
Any
*>
(
&
node
);
reference
::
any
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
any
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
any
->
get_reduction_axes
());
...
...
@@ -268,16 +294,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmin
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
reference
::
argmin
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
}
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmin
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
argmin
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmin
->
get_reduction_axis
());
...
...
@@ -294,16 +320,16 @@ private:
auto
element_type
=
node
.
get_output_element_type
(
0
);
if
(
element_type
==
element
::
i64
)
{
reference
::
argmax
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
reference
::
argmax
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
}
else
if
(
element_type
==
element
::
i32
)
{
reference
::
argmax
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
argmax
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
argmax
->
get_reduction_axis
());
...
...
@@ -318,22 +344,22 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
asin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Atan
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
atan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
AvgPool
:
{
const
op
::
AvgPool
*
avg_pool
=
static_cast
<
const
op
::
AvgPool
*>
(
&
node
);
reference
::
avg_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
avg_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
avg_pool
->
get_window_shape
(),
...
...
@@ -345,18 +371,30 @@ private:
}
case
OP_TYPEID
:
:
GenerateMask
:
{
bool
use_seed
=
static_cast
<
bool
>
(
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
()[
0
]);
if
(
m_states
.
count
(
&
node
)
==
0
)
{
const
op
::
GenerateMask
*
gm
=
static_cast
<
const
op
::
GenerateMask
*>
(
&
node
);
auto
seed
=
use_seed
?
gm
->
get_seed
()
:
0
;
m_states
[
&
node
]
=
std
::
unique_ptr
<
ngraph
::
RNGState
>
(
ngraph
::
RNGState
::
create_rng_state
(
gm
->
get_seed
()
,
gm
->
get_probability
()));
ngraph
::
RNGState
::
create_rng_state
(
seed
,
gm
->
get_probability
()));
}
bool
training
=
static_cast
<
bool
>
(
static_cast
<
const
T
*>
(
args
[
0
]
)[
0
]);
bool
training
=
static_cast
<
bool
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
)[
0
]);
auto
state
=
m_states
.
at
(
&
node
).
get
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
if
(
!
use_seed
)
{
reference
::
generate_mask
<
T
>
(
reinterpret_cast
<
T
*>
(
out
[
0
]),
element_count
,
state
,
training
);
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
state
,
training
);
}
else
{
uint64_t
seed
=
static_cast
<
uint64_t
>
(
args
[
3
]
->
get_data_ptr
<
const
T
>
()[
0
]);
double
prob
=
static_cast
<
double
>
(
args
[
4
]
->
get_data_ptr
<
const
T
>
()[
0
]);
reference
::
generate_mask_no_state
<
T
>
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
,
training
,
seed
,
prob
);
}
break
;
}
case
OP_TYPEID
:
:
GetOutputElement
:
...
...
@@ -366,20 +404,31 @@ private:
size_t
n
=
get_output_element
->
get_n
();
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
num_bytes
=
element_count
*
node
.
get_output_element_type
(
0
).
size
();
std
::
memcpy
(
static_cast
<
T
*>
(
out
[
0
]),
args
[
n
],
num_bytes
);
std
::
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
n
]
->
get_data_ptr
<
T
>
(),
num_bytes
);
break
;
}
case
OP_TYPEID
:
:
BatchMatMul
:
{
reference
::
batch_mat_mul
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
));
break
;
}
case
OP_TYPEID
:
:
BatchNormTraining
:
{
const
ngraph
::
op
::
BatchNormTraining
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormTraining
*>
(
&
node
);
reference
::
batch_norm_training
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
T
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
static_cast
<
T
*>
(
out
[
2
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
...
...
@@ -388,12 +437,12 @@ private:
const
ngraph
::
op
::
BatchNormInference
*
bn
=
static_cast
<
const
ngraph
::
op
::
BatchNormInference
*>
(
&
node
);
reference
::
batch_norm_inference
<
T
>
(
bn
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
T
*>
(
out
[
0
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
...
...
@@ -402,23 +451,23 @@ private:
const
ngraph
::
op
::
BatchNormTrainingBackprop
*
bn_bprop
=
static_cast
<
const
ngraph
::
op
::
BatchNormTrainingBackprop
*>
(
&
node
);
reference
::
batch_norm_backprop
(
bn_bprop
->
get_eps_value
(),
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
const
T
*>
(
args
[
3
]
),
static_cast
<
const
T
*>
(
args
[
4
]
),
static_cast
<
const
T
*>
(
args
[
5
]
),
static_cast
<
T
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
static_cast
<
T
*>
(
out
[
2
]
),
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
args
[
3
]
->
get_data_ptr
<
const
T
>
(
),
args
[
4
]
->
get_data_ptr
<
const
T
>
(
),
args
[
5
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
out
[
2
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
2
));
break
;
}
case
OP_TYPEID
:
:
AvgPoolBackprop
:
{
const
op
::
AvgPoolBackprop
*
apb
=
static_cast
<
const
op
::
AvgPoolBackprop
*>
(
&
node
);
reference
::
avg_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
avg_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
apb
->
get_window_shape
(),
...
...
@@ -434,8 +483,8 @@ private:
Shape
in_shape
=
node
.
get_input_shape
(
0
);
Shape
out_shape
=
node
.
get_output_shape
(
0
);
AxisSet
broadcast_axes
=
broadcast
->
get_broadcast_axes
();
gcpu
::
kernel
::
broadcast
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
kernel
::
broadcast
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shape
,
out_shape
,
broadcast_axes
);
...
...
@@ -443,23 +492,28 @@ private:
}
case
OP_TYPEID
:
:
BroadcastDistributed
:
{
int
rank_ID
=
get_distributed_interface
()
->
get_rank
();
if
(
rank_ID
==
0
)
const
ngraph
::
op
::
BroadcastDistributed
*
broadcast
=
static_cast
<
const
ngraph
::
op
::
BroadcastDistributed
*>
(
&
node
);
int
rank_ID
;
rank_ID
=
get_distributed_interface
()
->
get_rank
();
int
root_id
=
broadcast
->
get_root_id
();
if
(
rank_ID
==
root_id
)
{
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
args
[
0
]
),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
node
.
get_input_element_type
(
0
)
);
memcpy
(
out
[
0
]
,
args
[
0
]
,
memSize
);
args
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_element_type
(
0
)
.
get_type_enum
()
,
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
,
root_id
);
auto
memSize
=
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
)))
*
sizeof
(
T
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
()
,
memSize
);
}
else
{
reference
::
broadcastdistributed
<
T
>
(
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_element_type
(
0
),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))));
out
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
static_cast
<
int
>
(
shape_size
(
node
.
get_input_shape
(
0
))),
root_id
);
}
break
;
}
...
...
@@ -468,7 +522,7 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
ceiling
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Concat
:
...
...
@@ -478,11 +532,11 @@ private:
std
::
vector
<
Shape
>
in_shapes
;
for
(
size_t
i
=
0
;
i
<
node
.
get_input_size
();
i
++
)
{
in_args
.
push_back
(
static_cast
<
const
T
*>
(
args
[
i
]
));
in_args
.
push_back
(
args
[
i
]
->
get_data_ptr
<
const
T
>
(
));
in_shapes
.
push_back
(
node
.
get_input_shape
(
i
));
}
reference
::
concat
<
T
>
(
in_args
,
static_cast
<
T
*>
(
out
[
0
]
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
in_shapes
,
node
.
get_output_shape
(
0
),
concat
->
get_concatenation_axis
());
...
...
@@ -492,7 +546,7 @@ private:
{
const
op
::
Constant
*
c
=
static_cast
<
const
op
::
Constant
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
reference
::
constant
<
T
>
(
c
->
get_data_ptr
<
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ScalarConstantLike
:
break
;
...
...
@@ -505,52 +559,62 @@ private:
switch
(
type
.
get_type_enum
())
{
case
element
:
:
Type_t
::
boolean
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
char
*>
(
out
[
0
]
),
element_count
);
reference
::
convert
_to_bool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
case
element
:
:
Type_t
::
f32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
float
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
element_count
);
break
;
case
element
:
:
Type_t
::
f64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
double
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
double
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i8
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int8_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i16
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int16_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int16_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int32_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
i64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
int64_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u8
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint8_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u16
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint16_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint16_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u32
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint32_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint32_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
u64
:
reference
::
convert
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
uint64_t
*>
(
out
[
0
]),
element_count
);
reference
::
convert
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(),
element_count
);
break
;
case
element
:
:
Type_t
::
undefined
:
case
element
:
:
Type_t
::
dynamic
:
case
element
:
:
Type_t
::
bf16
:
case
element
:
:
Type_t
::
f16
:
ss
<<
"unsupported element type "
<<
type
<<
" op Convert"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
...
...
@@ -559,9 +623,9 @@ private:
case
OP_TYPEID
:
:
Convolution
:
{
const
op
::
Convolution
*
c
=
static_cast
<
const
op
::
Convolution
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
convolution
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
...
...
@@ -569,38 +633,26 @@ private:
c
->
get_window_dilation_strides
(),
c
->
get_padding_below
(),
c
->
get_padding_above
(),
c
->
get_data_dilation_strides
(),
0
,
1
,
1
,
0
,
0
,
1
,
false
);
c
->
get_data_dilation_strides
());
break
;
}
case
OP_TYPEID
:
:
ConvolutionBackpropFilters
:
{
const
op
::
ConvolutionBackpropFilters
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropFilters
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
1
,
0
,
0
,
1
,
1
,
0
,
false
);
reference
::
convolution_backprop_filter
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
// input
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
// delta_convolution_output
out
[
0
]
->
get_data_ptr
<
T
>
(),
// delta_filter
c
->
get_input_shape
(
0
),
// input_shape
c
->
get_input_shape
(
1
),
// convolution_output_shape
c
->
get_filters_shape
(),
// filter_shape
c
->
get_window_dilation_strides_forward
(),
c
->
get_window_movement_strides_forward
(),
c
->
get_padding_below_forward
(),
c
->
compute_backward_in_pad_above
(),
c
->
get_data_dilation_strides_forward
());
break
;
}
case
OP_TYPEID
:
:
ConvolutionBackpropData
:
...
...
@@ -608,38 +660,31 @@ private:
// Note that args[1] and args[0] are switched here from the usual order.
const
op
::
ConvolutionBackpropData
*
c
=
static_cast
<
const
op
::
ConvolutionBackpropData
*>
(
&
node
);
reference
::
convolution
<
T
>
(
static_cast
<
const
T
*>
(
args
[
1
]),
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
node
.
get_input_shape
(
1
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
c
->
get_window_movement_strides_backward
(),
c
->
get_window_dilation_strides_backward
(),
c
->
get_padding_below_backward
(),
c
->
get_padding_above_backward
(),
c
->
get_data_dilation_strides_backward
(),
0
,
1
,
0
,
1
,
0
,
1
,
true
);
reference
::
convolution_backprop_in
<
T
>
(
args
[
1
]
->
get_data_ptr
<
const
T
>
(),
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
c
->
get_input_shape
(
1
),
c
->
get_input_shape
(
0
),
c
->
get_data_batch_shape
(),
c
->
get_data_dilation_strides_forward
(),
c
->
get_window_dilation_strides_forward
(),
c
->
compute_backward_delta_out_pad_below
(),
c
->
compute_backward_delta_out_pad_above
(),
c
->
get_window_movement_strides_forward
());
break
;
}
case
OP_TYPEID
:
:
Cos
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cos
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Cosh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
cosh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Dequantize
:
...
...
@@ -649,20 +694,20 @@ private:
if
(
type
==
element
::
f32
)
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
float
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
float
*>
(
out
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
float
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
float
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
dequantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
double
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
double
*>
(
out
[
0
]
),
reference
::
dequantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
double
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
double
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
dequantize
->
get_axes
());
...
...
@@ -680,9 +725,9 @@ private:
{
const
op
::
Divide
*
divop
=
static_cast
<
const
op
::
Divide
*>
(
&
node
);
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
divide
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
divide
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
divop
->
is_pythondiv
());
break
;
...
...
@@ -691,15 +736,25 @@ private:
{
const
op
::
Dot
*
dot
=
static_cast
<
const
op
::
Dot
*>
(
&
node
);
gcpu
::
kernel
::
dot
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
kernel
::
dot
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
dot
->
get_reduction_axes_count
());
break
;
}
case
OP_TYPEID
:
:
DynReshape
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
DynSlice
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
EmbeddingLookup
:
{
const
op
::
EmbeddingLookup
*
embed
=
static_cast
<
const
op
::
EmbeddingLookup
*>
(
&
node
);
...
...
@@ -708,33 +763,33 @@ private:
if
(
type
==
element
::
f32
)
{
reference
::
embedding
<
T
,
float
>
(
static_cast
<
const
float
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
float
>
(
args
[
0
]
->
get_data_ptr
<
const
float
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
f64
)
{
reference
::
embedding
<
T
,
double
>
(
static_cast
<
const
double
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
double
>
(
args
[
0
]
->
get_data_ptr
<
const
double
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
i32
)
{
reference
::
embedding
<
T
,
int
>
(
static_cast
<
const
int
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
int
32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
else
if
(
type
==
element
::
i64
)
{
reference
::
embedding
<
T
,
int64_t
>
(
static_cast
<
const
int64_t
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
embedding
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
int64_t
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
,
embed
->
get_shape
());
}
...
...
@@ -748,24 +803,56 @@ private:
case
OP_TYPEID
:
:
Equal
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Erf
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
erf
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Exp
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
exp
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(),
element_count
);
break
;
}
#ifdef INTERPRETER_USE_HYBRID
case
OP_TYPEID
:
:
FunctionCall
:
{
auto
f
=
static_cast
<
const
runtime
::
hybrid
::
op
::
FunctionCall
*>
(
&
node
);
auto
backend
=
f
->
get_backend
();
auto
executable
=
f
->
get_executable
();
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
outputs
;
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
inputs
;
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
out
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
outputs
.
push_back
(
backend_tensor
);
}
for
(
const
std
::
shared_ptr
<
HostTensor
>&
t
:
args
)
{
auto
backend_tensor
=
backend
->
create_tensor
(
t
->
get_element_type
(),
t
->
get_shape
(),
t
->
get_data_ptr
());
inputs
.
push_back
(
backend_tensor
);
}
executable
->
call
(
outputs
,
inputs
);
break
;
}
#endif
case
OP_TYPEID
:
:
Floor
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
floor
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Gather
:
...
...
@@ -826,36 +913,36 @@ private:
case
OP_TYPEID
:
:
Greater
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
greater
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
GreaterEq
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
greater_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
greater_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Less
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
less
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LessEq
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
less_eq
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
less_eq
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
...
...
@@ -863,14 +950,14 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
log
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
LRN
:
{
const
op
::
LRN
*
lrn
=
static_cast
<
const
op
::
LRN
*>
(
&
node
);
reference
::
lrn
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
lrn
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
lrn
->
get_alpha
(),
lrn
->
get_beta
(),
...
...
@@ -881,8 +968,8 @@ private:
case
OP_TYPEID
:
:
Max
:
{
const
op
::
Max
*
max
=
static_cast
<
const
op
::
Max
*>
(
&
node
);
reference
::
max
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max
->
get_reduction_axes
());
...
...
@@ -891,9 +978,9 @@ private:
case
OP_TYPEID
:
:
Maximum
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
maximum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
maximum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -901,8 +988,8 @@ private:
{
const
op
::
MaxPool
*
max_pool
=
static_cast
<
const
op
::
MaxPool
*>
(
&
node
);
reference
::
max_pool
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max_pool
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
max_pool
->
get_window_shape
(),
...
...
@@ -916,9 +1003,9 @@ private:
const
op
::
MaxPoolBackprop
*
max_pool_backprop
=
static_cast
<
const
op
::
MaxPoolBackprop
*>
(
&
node
);
reference
::
max_pool_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
max_pool_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
max_pool_backprop
->
get_window_shape
(),
...
...
@@ -930,8 +1017,8 @@ private:
case
OP_TYPEID
:
:
Min
:
{
const
op
::
Min
*
min
=
static_cast
<
const
op
::
Min
*>
(
&
node
);
reference
::
min
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
min
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
min
->
get_reduction_axes
());
...
...
@@ -940,18 +1027,18 @@ private:
case
OP_TYPEID
:
:
Minimum
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
minimum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
minimum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Multiply
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
multiply
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
multiply
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -959,30 +1046,30 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
negate
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Not
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_not
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
NotEqual
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
not_equal
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
char
*>
(
out
[
0
]
),
reference
::
not_equal
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
char
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
OneHot
:
{
const
op
::
OneHot
*
oh
=
static_cast
<
const
op
::
OneHot
*>
(
&
node
);
reference
::
one_hot
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
one_hot
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
oh
->
get_one_hot_axis
());
...
...
@@ -991,46 +1078,46 @@ private:
case
OP_TYPEID
:
:
Or
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
logical_or
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
logical_or
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Parameter
:
break
;
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Pad
:
{
const
op
::
Pad
*
pad
=
static_cast
<
const
op
::
Pad
*>
(
&
node
);
reference
::
pad
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
node
.
get_inputs
().
a
t
(
0
).
get_shape
(),
node
.
get_output_shape
(
0
),
reference
::
pad
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
inpu
t
(
0
).
get_shape
(),
node
.
output
(
0
).
get_shape
(
),
pad
->
get_padding_below
(),
pad
->
get_padding_above
(),
pad
->
get_pad
ding_interior
());
pad
->
get_pad
_mode
());
break
;
}
case
OP_TYPEID
:
:
Passthrough
:
{
const
op
::
Passthrough
*
passthrough
=
static_cast
<
const
op
::
Passthrough
*>
(
&
node
);
throw
unsupported_op
{
"Unsupported operation language: "
+
passthrough
->
language
()};
}
case
OP_TYPEID
:
:
Power
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
power
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
power
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Product
:
{
const
op
::
Product
*
product
=
static_cast
<
const
op
::
Product
*>
(
&
node
);
reference
::
product
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
product
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
product
->
get_reduction_axes
());
...
...
@@ -1043,10 +1130,10 @@ private:
if
(
type
==
element
::
u8
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
uint8_t
*>
(
args
[
2
]
),
static_cast
<
uint8_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
uint8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1054,10 +1141,10 @@ private:
}
else
if
(
type
==
element
::
i8
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
int8_t
*>
(
args
[
2
]
),
static_cast
<
int8_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
int8_t
>
(
),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1065,10 +1152,10 @@ private:
}
else
if
(
type
==
element
::
i32
)
{
reference
::
quantize
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
int32_t
*>
(
args
[
2
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
reference
::
quantize
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
int32_t
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
quantize
->
get_axes
(),
...
...
@@ -1083,40 +1170,168 @@ private:
break
;
}
case
OP_TYPEID
:
:
QuantizedConvolution
:
{
const
op
::
QuantizedConvolution
*
qc
=
static_cast
<
const
op
::
QuantizedConvolution
*>
(
&
node
);
auto
input_element_type
=
qc
->
get_input_element_type
(
0
);
auto
filter_element_type
=
qc
->
get_input_element_type
(
1
);
auto
output_element_type
=
qc
->
get_output_element_type
(
0
);
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i8
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
u8
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
uint8_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
uint8_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
uint8_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
i8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
int8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
int8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
int8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
if
(
input_element_type
==
element
::
u8
&&
filter_element_type
==
element
::
u8
&&
output_element_type
==
element
::
i32
)
{
reference
::
convolution
<
uint8_t
,
uint8_t
,
int32_t
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
1
]
->
get_data_ptr
<
const
uint8_t
>
(),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(),
node
.
get_input_shape
(
0
),
node
.
get_input_shape
(
1
),
node
.
get_output_shape
(
0
),
qc
->
get_window_movement_strides
(),
qc
->
get_window_dilation_strides
(),
qc
->
get_padding_below
(),
qc
->
get_padding_above
(),
qc
->
get_data_dilation_strides
(),
args
[
2
]
->
get_data_ptr
<
const
float
>
(),
args
[
3
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
4
]
->
get_data_ptr
<
const
float
>
(),
args
[
5
]
->
get_data_ptr
<
const
uint8_t
>
(),
args
[
6
]
->
get_data_ptr
<
const
float
>
(),
args
[
7
]
->
get_data_ptr
<
const
int32_t
>
());
}
else
{
std
::
stringstream
ss
;
ss
<<
"unsupported element type"
;
throw
std
::
runtime_error
(
ss
.
str
());
}
break
;
}
case
OP_TYPEID
:
:
QuantizedAvgPool
:
case
OP_TYPEID
:
:
QuantizedConvolutionBias
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionBiasSignedAdd
:
case
OP_TYPEID
:
:
QuantizedConvolutionRelu
:
case
OP_TYPEID
:
:
QuantizedConvolution
:
case
OP_TYPEID
:
:
QuantizedMaxPool
:
case
OP_TYPEID
:
:
QuantizedDotBias
:
case
OP_TYPEID
:
:
QuantizedDot
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'."
);
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"' in Interpreter back end."
);
}
case
OP_TYPEID
:
:
Recv
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Recv
*>
(
&
node
);
int
src_id
=
op
->
get_src_id
();
reference
::
recv
<
T
>
(
args
[
0
]
->
get_data_ptr
<
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
src_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
Range
:
{
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
break
;
}
case
OP_TYPEID
:
:
Relu
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReluBackprop
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
relu_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
relu_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
ReplaceSlice
:
{
const
op
::
ReplaceSlice
*
slice
=
static_cast
<
const
op
::
ReplaceSlice
*>
(
&
node
);
reference
::
replace_slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
replace_slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
1
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
...
...
@@ -1127,8 +1342,8 @@ private:
case
OP_TYPEID
:
:
Reshape
:
{
const
op
::
Reshape
*
reshape
=
static_cast
<
const
op
::
Reshape
*>
(
&
node
);
gcpu
::
kernel
::
reshape
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
kernel
::
reshape
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
reshape
->
get_input_order
(),
node
.
get_output_shape
(
0
));
...
...
@@ -1137,16 +1352,16 @@ private:
case
OP_TYPEID
:
:
Result
:
{
const
op
::
Result
*
res
=
static_cast
<
const
op
::
Result
*>
(
&
node
);
reference
::
result
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
result
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
shape_size
(
res
->
get_shape
()));
break
;
}
case
OP_TYPEID
:
:
Reverse
:
{
const
op
::
Reverse
*
reverse
=
static_cast
<
const
op
::
Reverse
*>
(
&
node
);
reference
::
reverse
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
reverse
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
reverse
->
get_reversed_axes
());
...
...
@@ -1158,12 +1373,12 @@ private:
if
(
node
.
get_input_element_type
(
1
)
==
element
::
i32
)
{
reference
::
reverse_sequence
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
reverse_sequence
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
reverse
->
get_batch_axis
(),
reverse
->
get_sequence_axis
(),
static_cast
<
const
int32_t
*>
(
args
[
1
]
));
args
[
1
]
->
get_data_ptr
<
const
int32_t
>
(
));
}
else
{
...
...
@@ -1234,31 +1449,46 @@ private:
case
OP_TYPEID
:
:
Select
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
select
<
T
>
(
static_cast
<
const
char
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
const
T
*>
(
args
[
2
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
select
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
char
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
args
[
2
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Send
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
size_t
memSize
=
element_count
*
sizeof
(
T
);
const
auto
*
op
=
static_cast
<
const
ngraph
::
op
::
Send
*>
(
&
node
);
int
dest_id
=
op
->
get_dest_id
();
reference
::
send
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
node
.
get_input_element_type
(
0
).
get_type_enum
(),
element_count
,
dest_id
);
memcpy
(
out
[
0
]
->
get_data_ptr
<
T
>
(),
args
[
0
]
->
get_data_ptr
<
T
>
(),
memSize
);
break
;
}
case
OP_TYPEID
:
:
ShapeOf
:
{
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
static_cast
<
uint64_t
*>
(
out
[
0
]
));
reference
::
shape_of
(
node
.
get_input_shape
(
0
),
out
[
0
]
->
get_data_ptr
<
uint64_t
>
(
));
break
;
}
case
OP_TYPEID
:
:
Sigmoid
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
SigmoidBackprop
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sigmoid_backprop
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
sigmoid_backprop
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
...
...
@@ -1266,28 +1496,28 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sign
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sin
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sin
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sinh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sinh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Slice
:
{
const
op
::
Slice
*
slice
=
static_cast
<
const
op
::
Slice
*>
(
&
node
);
reference
::
slice
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
slice
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
slice
->
get_lower_bounds
(),
slice
->
get_upper_bounds
(),
...
...
@@ -1298,8 +1528,8 @@ private:
case
OP_TYPEID
:
:
Softmax
:
{
const
op
::
Softmax
*
softmax
=
static_cast
<
const
op
::
Softmax
*>
(
&
node
);
reference
::
softmax
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
softmax
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_output_shape
(
0
),
softmax
->
get_axes
());
break
;
...
...
@@ -1308,7 +1538,7 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
sqrt
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
StopGradient
:
{
throw
unsupported_op
(
"Unsupported op 'StopGradient'"
);
...
...
@@ -1316,17 +1546,17 @@ private:
case
OP_TYPEID
:
:
Subtract
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
subtract
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
const
T
*>
(
args
[
1
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
subtract
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
args
[
1
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Sum
:
{
const
op
::
Sum
*
sum
=
static_cast
<
const
op
::
Sum
*>
(
&
node
);
reference
::
sum
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
T
*>
(
out
[
0
]
),
reference
::
sum
<
T
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
sum
->
get_reduction_axes
());
...
...
@@ -1336,14 +1566,14 @@ private:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tan
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
Tanh
:
{
size_t
element_count
=
shape_size
(
node
.
get_output_shape
(
0
));
reference
::
tanh
<
T
>
(
static_cast
<
const
T
*>
(
args
[
0
]),
static_cast
<
T
*>
(
out
[
0
]
),
element_count
);
args
[
0
]
->
get_data_ptr
<
const
T
>
(),
out
[
0
]
->
get_data_ptr
<
T
>
(
),
element_count
);
break
;
}
case
OP_TYPEID
:
:
TopK
:
...
...
@@ -1351,9 +1581,9 @@ private:
const
op
::
TopK
*
topk
=
static_cast
<
const
op
::
TopK
*>
(
&
node
);
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i64
)
{
reference
::
topk
<
T
,
int64_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int64_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
reference
::
topk
<
T
,
int64_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int64_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
...
...
@@ -1362,9 +1592,9 @@ private:
}
else
if
(
node
.
get_output_element_type
(
0
)
==
element
::
i32
)
{
reference
::
topk
<
T
,
int32_t
>
(
static_cast
<
const
T
*>
(
args
[
0
]
),
static_cast
<
int32_t
*>
(
out
[
0
]
),
static_cast
<
T
*>
(
out
[
1
]
),
reference
::
topk
<
T
,
int32_t
>
(
args
[
0
]
->
get_data_ptr
<
const
T
>
(
),
out
[
0
]
->
get_data_ptr
<
int32_t
>
(
),
out
[
1
]
->
get_data_ptr
<
T
>
(
),
node
.
get_input_shape
(
0
),
node
.
get_output_shape
(
0
),
topk
->
get_top_k_axis
(),
...
...
@@ -1377,7 +1607,12 @@ private:
}
break
;
}
default
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
case
OP_TYPEID
:
:
DynBroadcast
:
case
OP_TYPEID
:
:
Transpose
:
case
OP_TYPEID
:
:
DynPad
:
case
OP_TYPEID
:
:
Tile
:
case
OP_TYPEID
:
:
DynReplaceSlice
:
throw
unsupported_op
(
"Unsupported op '"
+
node
.
description
()
+
"'"
);
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
...
...
src/ngraph/runtime/generic_cpu/kernel/broadcast.hpp
View file @
60252edd
...
...
@@ -140,6 +140,91 @@ namespace ngraph
}
}
template
<
typename
T
>
void
broadcast_5d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
5
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
5
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
+
index
[
3
]
*
out_shape
[
4
]
+
index
[
4
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
template
<
typename
T
>
void
broadcast_6d
(
const
T
*
in
,
T
*
out
,
const
Shape
&
in_shape
,
const
Shape
&
out_shape
,
const
AxisSet
&
broadcast_axes
)
{
size_t
index
[
6
];
size_t
*
out_index
=
0
;
for
(
size_t
i
=
0
;
i
<
6
;
i
++
)
{
if
(
broadcast_axes
.
count
(
i
)
==
0
)
{
out_index
=
&
index
[
i
];
break
;
}
}
for
(
index
[
0
]
=
0
;
index
[
0
]
<
out_shape
[
0
];
++
index
[
0
])
{
for
(
index
[
1
]
=
0
;
index
[
1
]
<
out_shape
[
1
];
++
index
[
1
])
{
for
(
index
[
2
]
=
0
;
index
[
2
]
<
out_shape
[
2
];
++
index
[
2
])
{
for
(
index
[
3
]
=
0
;
index
[
3
]
<
out_shape
[
3
];
++
index
[
3
])
{
for
(
index
[
4
]
=
0
;
index
[
4
]
<
out_shape
[
4
];
++
index
[
4
])
{
for
(
index
[
5
]
=
0
;
index
[
5
]
<
out_shape
[
5
];
++
index
[
5
])
{
out
[
index
[
0
]
*
out_shape
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
1
]
*
out_shape
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
2
]
*
out_shape
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
3
]
*
out_shape
[
4
]
*
out_shape
[
5
]
+
index
[
4
]
*
out_shape
[
5
]
+
index
[
5
]]
=
in
[
*
out_index
];
}
}
}
}
}
}
}
template
<
typename
T
>
void
broadcast
(
const
T
*
in
,
T
*
out
,
...
...
@@ -167,6 +252,16 @@ namespace ngraph
case
4
:
broadcast_4d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
case
5
:
broadcast_5d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
case
6
:
broadcast_6d
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
default
:
runtime
::
reference
::
broadcast
<
T
>
(
in
,
out
,
in_shape
,
out_shape
,
broadcast_axes
);
break
;
}
}
else
...
...
src/ngraph/runtime/generic_cpu/kernel/reshape.hpp
View file @
60252edd
...
...
@@ -244,10 +244,7 @@ namespace ngraph
case
4
:
reshape_in4
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
5
:
reshape_in5
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
case
6
:
reshape_in6
<
T
>
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
default
:
NGRAPH_INFO
<<
"reference::reshape"
;
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
default
:
reference
::
reshape
(
in
,
out
,
in_shape
,
in_axis_order
,
out_shape
);
break
;
}
}
}
...
...
src/ngraph/runtime/generic_cpu/kernel/result.hpp
deleted
100644 → 0
View file @
341205cf
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace
ngraph
{
namespace
runtime
{
namespace
gcpu
{
namespace
kernel
{
template
<
typename
T
>
void
result
(
const
T
*
arg
,
T
*
out
,
size_t
count
)
{
memcpy
(
out
,
arg
,
sizeof
(
T
)
*
count
);
}
}
}
}
}
src/ngraph/runtime/generic_cpu/node_wrapper.hpp
View file @
60252edd
...
...
@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public
:
NodeWrapper
(
const
std
::
shared_ptr
<
const
ngraph
::
Node
>&
node
);
const
Node
&
get_node
()
const
{
return
*
m_node
;
}
std
::
shared_ptr
<
const
Node
>
get_node
()
const
{
return
m_node
;
}
ngraph
::
runtime
::
gcpu
::
OP_TYPEID
get_typeid
()
const
{
return
m_typeid
;
}
private
:
std
::
shared_ptr
<
const
ngraph
::
Node
>
m_node
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment