Unverified Commit 60252edd authored by Scott Cyphers, committed by GitHub

Merge branch 'master' into ayzhuang/batch_norm_infer_relu_fusion

parents 341205cf 47342339
...@@ -18,10 +18,10 @@ include(ExternalProject)
 # Includes blas 3.8.0 in mkldnn
 set(NGRAPH_MKLDNN_SHORT_VERSION 0)
-set(NGRAPH_MKLDNN_FULL_VERSION 0.19.0.0)
-set(NGRAPH_MKLDNN_VERSION "v0.19")
+set(NGRAPH_MKLDNN_FULL_VERSION 0.20.0.0)
+set(NGRAPH_MKLDNN_VERSION "v0.20")
 set(NGRAPH_MKLDNN_SUB_VERSION "2019.0.5.20190502")
-set(NGRAPH_MKLDNN_GIT_TAG "027de76")
+set(NGRAPH_MKLDNN_GIT_TAG "v0.20")
 #------------------------------------------------------------------------------
 # Fetch and install MKL-DNN
......
...@@ -28,16 +28,3 @@ index f10feb20..05f47961 100644
set_property(TARGET ${LIB_NAME} PROPERTY PUBLIC_HEADER ${HEADERS})
target_include_directories(${LIB_NAME} PUBLIC
diff --git a/src/cpu/jit_avx512_common_conv_kernel.cpp b/src/cpu/jit_avx512_common_conv_kernel.cpp
index 1bb98fa43..b8b54401f 100644
--- a/src/cpu/jit_avx512_common_conv_kernel.cpp
+++ b/src/cpu/jit_avx512_common_conv_kernel.cpp
@@ -3055,7 +3055,7 @@ void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel_3d() {
void jit_avx512_common_conv_bwd_weights_kernel_f32
::compute_oh_loop_common()
{
- assert(jcp.harness == harness_mb_reduction);
+ assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));
int b_pad = jcp.b_pad;
int t_pad = jcp.t_pad;
bool is_dilated = jcp.dilate_h != 0;
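The relaxed assertion above uses a small variadic membership helper. For readers unfamiliar with the idiom, the following is an illustrative stand-in for such a one_of() helper (a sketch only, not the mkl-dnn implementation):

// Illustrative sketch of a one_of() membership helper, similar in spirit to the
// utility used in the assertion above; not the actual mkl-dnn code.
template <typename T>
bool one_of(T value, T candidate)
{
    return value == candidate;
}

template <typename T, typename... Rest>
bool one_of(T value, T candidate, Rest... rest)
{
    // True if value matches this candidate or any of the remaining ones.
    return value == candidate || one_of(value, rest...);
}

// Usage mirrors the patched assertion:
//   assert(one_of(jcp.harness, harness_mb_reduction, harness_3d_reduction));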
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
...@@ -73,11 +71,11 @@ author = 'Intel Corporation'
 # built documents.
 #
 # The short X.Y version.
-version = '0.22'
+version = '0.23'
 # The Documentation full version, including alpha/beta/rc tags. Some features
 # available in the latest code will not necessarily be documented first
-release = '0.22.0'
+release = '0.23.0'
 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
......
...@@ -9,11 +9,11 @@
 <dt>{{ _('Recent Versions') }}</dt>
 <dd><!-- Until our https://docs.ngraph.ai/ publishing is set up, we link to GitHub -->
 <ul>
-<li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.22.0">0.22</a></li>
+<li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.23.0">0.23.0</a></li>
+<li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.22.0">0.22.0</a></li>
 <li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.21.0">0.21.0</a></li>
 <li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.20.0">0.20.0</a></li>
 <li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.19.0">0.19.0</a></li>
-<li><a href="https://github.com/NervanaSystems/ngraph/releases/tag/v0.18.1">0.18.1</a></li>
 </ul></dd>
 </dl>
 <dl>
......
...@@ -6,28 +6,30 @@ Release Notes
 nGraph is provided as source code, APIs, build scripts, and some binary formats
 for various Compiler stack configurations and use cases.
+For downloads formatted as ``.zip`` and ``tar.gz``, see
+https://github.com/NervanaSystems/ngraph/releases.
 This page includes additional documentation updates.
 We are pleased to announce the release of version |version|-doc.
 ==============================
 Core updates for |version|
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~
++ PlaidML support
 + More ONNX ops
-+ Optimizations
-+ Don't reseed RNG on each use
++ Elementwise divide defaults to Python semantics
++ GenerateMask seed optional
-0.22-doc
---------
-+ Initial doc and API for IntelGPU backend.
-+ DynamicBackend API.
-+ Note deprecation of support of MXNet's ``ngraph-mxnet`` PyPI.
-+ Noted changes on graph inspection options resultant from PR 3016.
-+ Added better tips and details to doc-contributor-README.
+Latest doc updates
+~~~~~~~~~~~~~~~~~~
++ Document new debug tool
++ Note deprecation of MXNet's ``ngraph-mxnet`` PyPI
++ Note default change to `svg` files for graphs and visualization
++ Add more prominent tips for contributors who find the doc-contributor-README
 .. important:: Pre-releases (``-rc-0.*``) have newer features, and are less stable.
...@@ -36,8 +38,15 @@ Core updates for |version|
 Changelog on Previous Releases
 ==============================
-For downloads formatted as ``.zip`` and ``tar.gz``, see
-https://github.com/NervanaSystems/ngraph/releases.
+0.22
+----
++ More ONNX ops
++ Optimizations
++ Don't reseed RNG on each use
++ Initial doc and API for IntelGPU backend
++ DynamicBackend API
 0.21
 ----
...@@ -51,12 +60,6 @@ https://github.com/NervanaSystems/ngraph/releases.
 + offset arg for tensor creation is deprecated
 + static linking support
 + Initial test of 0.21-doc
-0.21-doc
---------
-Summary of documentation-related changes:
 + Updated :doc:`doc-contributor-README` for new community-based contributions.
 + Added instructions on how to test or display the installed nGraph version.
 + Added instructions on building nGraph bridge (ngraph-bridge).
...@@ -82,8 +85,6 @@ Summary of documentation-related changes:
 0.19
 ----
-**Download** `0.19.0-rc.2`_
 + More dynamic shape preparation
 + Distributed interface factored out
 + fp16 and bfloat16 types
...@@ -103,9 +104,6 @@ Summary of documentation-related changes:
 0.18
 ----
-**Download** `0.18.1`_
 + Python formatting issue
 + mkl-dnn work-around
 + Event tracing improvements
...@@ -118,8 +116,6 @@ Summary of documentation-related changes:
 0.17
 ----
-**Download** `0.17.0-rc.1`_
 + Allow negative padding in more places
 + Add code generation for some quantized ops
 + Preliminary dynamic shape support
...@@ -131,11 +127,6 @@ Summary of documentation-related changes:
 0.16
 ----
-* **Download**: `0.16.0-rc.3`_
-* **Download** `0.16.0-rc.2`_
-* **Download** `0.16.0-rc.1`_
 + NodeInput and NodeOutput classes prepare for simplifications of Node
 + Test improvements
 + Additional quantization ops
...@@ -143,11 +134,3 @@ Summary of documentation-related changes:
 + Fix memory leak
 + Concat optimization
 + Doc updates
-.. _0.20.0-rc.0: https://github.com/NervanaSystems/ngraph/releases/tag/v0.20.0-rc.0_
-.. _0.19.0-rc.2: https://github.com/NervanaSystems/ngraph/releases/tag/v0.19.0-rc.2_
-.. _0.18.1: https://github.com/NervanaSystems/ngraph/releases/tag/v0.18.1_
-.. _0.17.0-rc.1: `https://github.com/NervanaSystems/ngraph/releases/tag/v0.17.0-rc.1
-.. _0.16.0-rc.3: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.3
-.. _0.16.0-rc.2: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.2
-.. _0.16.0-rc.1: https://github.com/NervanaSystems/ngraph/releases/tag/v0.16.0-rc.1
pytest
tox
pydocstyle==3.0.0
flake8
flake8-commas
flake8-comprehensions
......
...@@ -370,7 +370,6 @@ set (SRC
 op/util/index_reduction.hpp
 op/util/logical_reduction.cpp
 op/util/logical_reduction.hpp
-op/util/reshape.hpp
 op/util/rnn_cell_base.cpp
 op/util/rnn_cell_base.hpp
 op/util/unary_elementwise_arithmetic.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <memory>
#include <vector>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/node.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Change shape of input tensor.
///
/// \param[in] node The node producing the tensor to be reshaped.
/// \param[in] shape The new shape for input tensor.
///
/// \return The node representing a Reshape operation.
///
std::shared_ptr<ngraph::Node> reshape(const std::shared_ptr<ngraph::Node>& node,
const Shape& shape)
{
return builder::reshape(node, shape);
}
/// \brief Permute axes according to specified axes_order parameter.
///
/// \param node The node which axes we want to permute.
/// \param axes_order The permutation of node tensor axes.
///
/// \return: New node with permuted axes.
std::shared_ptr<ngraph::Node> reorder_axes(const std::shared_ptr<ngraph::Node>& node,
std::vector<std::size_t> axes_order)
{
return builder::reorder_axes(node, axes_order);
}
/// \brief Return transposed tensor (with axes in reversed order).
///
/// \param node Input tensor we want to transpose
///
/// \return: New node with reversed dimensions.
std::shared_ptr<ngraph::Node> transpose(const std::shared_ptr<ngraph::Node>& node)
{
return builder::transpose(node);
}
/// \brief Flatten the input tensor into a 2D matrix.
///
/// \param node The tensor to be flattened.
/// \param axis The axis dividing shape.
///
/// \return The new node will be a 2D matrix representing the flattened input node.
std::shared_ptr<ngraph::Node> flatten(const std::shared_ptr<ngraph::Node>& node,
int axis)
{
return builder::flatten(node, axis);
}
} // namespace util
} // namespace op
} // namespace ngraph
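Since the removed wrappers above simply forward to the ``ngraph::builder`` helpers, callers can use the builder API directly. A minimal usage sketch follows (the parameter name and shapes are illustrative, not taken from the diff):

// Minimal sketch of the builder helpers the removed wrappers forwarded to.
// Assumes the public ngraph headers; shapes and names are illustrative.
#include <memory>
#include <vector>

#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/type/element_type.hpp"

std::shared_ptr<ngraph::Node> reshape_example()
{
    auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                        ngraph::Shape{2, 3, 4});
    // Permute the original axes: {2, 3, 4} -> {4, 2, 3}.
    auto permuted = ngraph::builder::reorder_axes(data, {2, 0, 1});
    // Collapse everything after axis 1 into one dimension: {4, 2, 3} -> {4, 6}.
    auto flat = ngraph::builder::flatten(permuted, 1);
    // Explicit reshape to another layout with the same element count: {4, 6} -> {2, 12}.
    return ngraph::builder::reshape(flat, ngraph::Shape{2, 12});
}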
...@@ -49,7 +49,8 @@ public:
 }
 };
-std::unique_ptr<ngraph::runtime::Allocator> ngraph::runtime::create_default_allocator()
+ngraph::runtime::Allocator* ngraph::runtime::get_default_allocator()
 {
-    return std::unique_ptr<DefaultAllocator>(new DefaultAllocator());
+    static std::unique_ptr<DefaultAllocator> allocator(new DefaultAllocator());
+    return allocator.get();
 }
...@@ -30,7 +30,7 @@ namespace ngraph
 class DefaultAllocator;
 /// \brief Create a default allocator that calls into system
 /// allocation libraries
-std::unique_ptr<Allocator> create_default_allocator();
+ngraph::runtime::Allocator* get_default_allocator();
 }
 }
......
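The new ``get_default_allocator()`` hands out a pointer to a single lazily constructed allocator instead of giving each caller its own ``unique_ptr``. A rough self-contained sketch of that pattern (the class definitions here are simplified stand-ins, not the real ngraph::runtime types):

// Sketch of the function-local-static singleton pattern used above.
// Allocator and DefaultAllocator are simplified stand-ins.
#include <memory>

class Allocator
{
public:
    virtual ~Allocator() = default;
};

class DefaultAllocator : public Allocator
{
};

Allocator* get_default_allocator()
{
    // Constructed once, on first call, and shared by every caller;
    // callers must not delete the returned pointer.
    static std::unique_ptr<DefaultAllocator> allocator(new DefaultAllocator());
    return allocator.get();
}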
...@@ -185,7 +185,7 @@ runtime::Allocator* runtime::cpu::CPU_Backend::get_host_memory_allocator()
 {
     if (!m_allocator)
     {
-        m_allocator = create_default_allocator();
+        return runtime::get_default_allocator();
     }
     return m_allocator.get();
 }
......
...@@ -15,10 +15,10 @@
 # ******************************************************************************
 if (NGRAPH_GENERIC_CPU_ENABLE)
-    find_package(OpenMP)
-    if (OPENMP_FOUND)
-        add_compile_options(${OpenMP_CXX_FLAGS})
-    endif()
+    # find_package(OpenMP)
+    # if (OPENMP_FOUND)
+    #     add_compile_options(${OpenMP_CXX_FLAGS})
+    # endif()
     add_library(gcpu_backend SHARED gcpu_backend.cpp gcpu_executable.cpp node_wrapper.cpp)
     if(NGRAPH_LIB_VERSIONING_ENABLE)
         set_target_properties(gcpu_backend PROPERTIES
......
...@@ -52,14 +52,14 @@ runtime::gcpu::GCPUBackend::GCPUBackend(const vector<string>& unsupported_op_nam
 shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
                                                                       const Shape& shape)
 {
-    return make_shared<runtime::HostTensor>(type, shape, this);
+    return make_shared<runtime::HostTensor>(type, shape);
 }
 shared_ptr<runtime::Tensor> runtime::gcpu::GCPUBackend::create_tensor(const element::Type& type,
                                                                       const Shape& shape,
                                                                       void* memory_pointer)
 {
-    return make_shared<runtime::HostTensor>(type, shape, memory_pointer, this);
+    return make_shared<runtime::HostTensor>(type, shape, memory_pointer);
 }
 shared_ptr<runtime::Executable>
......
...@@ -15,17 +15,22 @@
 //*****************************************************************************
 #include "ngraph/runtime/generic_cpu/gcpu_executable.hpp"
+#include "ngraph/cpio.hpp"
 #include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
 #include "ngraph/except.hpp"
 #include "ngraph/op/convert.hpp"
 #include "ngraph/op/select.hpp"
 #include "ngraph/op/util/binary_elementwise_comparison.hpp"
 #include "ngraph/pass/assign_layout.hpp"
+#include "ngraph/pass/core_fusion.hpp"
+#include "ngraph/pass/fused_op_decomposition.hpp"
+#include "ngraph/pass/implicit_broadcast_elimination.hpp"
 #include "ngraph/pass/like_replacement.hpp"
 #include "ngraph/pass/liveness.hpp"
 #include "ngraph/pass/manager.hpp"
 #include "ngraph/pass/memory_layout.hpp"
 #include "ngraph/runtime/backend_manager.hpp"
+#include "ngraph/serializer.hpp"
 #include "ngraph/util.hpp"
 using namespace std;
...@@ -35,21 +40,35 @@ using descriptor::layout::DenseTensorLayout;
 runtime::gcpu::GCPUExecutable::GCPUExecutable(const shared_ptr<Function>& function,
                                               bool enable_performance_collection)
+    : m_is_compiled{true}
+    , m_performance_counters_enabled{enable_performance_collection}
 {
-    m_is_compiled = true;
+    m_function = clone_function(*function);
     pass::Manager pass_manager;
     pass_manager.register_pass<pass::LikeReplacement>();
+    pass_manager.register_pass<pass::FusedOpDecomposition>();
+    pass_manager.register_pass<pass::ImplicitBroadcastElimination>();
     pass_manager.register_pass<pass::AssignLayout<DenseTensorLayout>>();
     pass_manager.register_pass<pass::Liveness>();
-    pass_manager.run_passes(function);
-    for (const shared_ptr<Node>& node : function->get_ordered_ops())
+    pass_manager.run_passes(m_function);
+    for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
     {
         m_wrapped_nodes.emplace_back(node);
     }
-    set_parameters_and_results(*function);
+    set_parameters_and_results(*m_function);
+}
+runtime::gcpu::GCPUExecutable::GCPUExecutable(const std::string& model_string)
+    : m_is_compiled{true}
+    , m_performance_counters_enabled{false}
+{
+    m_function = deserialize(model_string);
+    for (const shared_ptr<Node>& node : m_function->get_ordered_ops())
+    {
+        m_wrapped_nodes.emplace_back(node);
+    }
+    set_parameters_and_results(*m_function);
 }
 bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
...@@ -82,7 +101,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{ {
for (size_t i = 0; i < param->get_output_size(); ++i) for (size_t i = 0; i < param->get_output_size(); ++i)
{ {
descriptor::Tensor* tensor = param->get_output_tensor_ptr(i).get(); descriptor::Tensor* tensor = &param->output(i).get_tensor();
tensor_map.insert({tensor, func_inputs[input_count++]}); tensor_map.insert({tensor, func_inputs[input_count++]});
} }
} }
...@@ -95,14 +114,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
{ {
throw ngraph_error("One of function's outputs isn't op::Result"); throw ngraph_error("One of function's outputs isn't op::Result");
} }
descriptor::Tensor* tensor = output->get_output_tensor_ptr(0).get(); descriptor::Tensor* tensor = &output->output(0).get_tensor();
tensor_map.insert({tensor, func_outputs[output_count]}); tensor_map.insert({tensor, func_outputs[output_count]});
} }
// for each ordered op in the graph // for each ordered op in the graph
for (const NodeWrapper& wrapped : m_wrapped_nodes) for (const NodeWrapper& wrapped : m_wrapped_nodes)
{ {
const Node* op = &wrapped.get_node(); auto op = wrapped.get_node();
auto type_id = wrapped.get_typeid(); auto type_id = wrapped.get_typeid();
if (type_id == OP_TYPEID::Parameter) if (type_id == OP_TYPEID::Parameter)
{ {
...@@ -111,9 +130,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
// get op inputs from map // get op inputs from map
vector<shared_ptr<HostTensor>> op_inputs; vector<shared_ptr<HostTensor>> op_inputs;
for (const descriptor::Input& input : op->get_inputs()) for (auto input : op->inputs())
{ {
descriptor::Tensor* tensor = input.get_output().get_tensor_ptr().get(); descriptor::Tensor* tensor = &input.get_tensor();
op_inputs.push_back(tensor_map.at(tensor)); op_inputs.push_back(tensor_map.at(tensor));
} }
...@@ -121,14 +140,14 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
vector<shared_ptr<HostTensor>> op_outputs; vector<shared_ptr<HostTensor>> op_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i) for (size_t i = 0; i < op->get_output_size(); ++i)
{ {
descriptor::Tensor* tensor = op->get_output_tensor_ptr(i).get(); descriptor::Tensor* tensor = &op->output(i).get_tensor();
shared_ptr<HostTensor> host_tensor; shared_ptr<HostTensor> host_tensor;
auto it = tensor_map.find(tensor); auto it = tensor_map.find(tensor);
if (it == tensor_map.end()) if (it == tensor_map.end())
{ {
const Shape& shape = op->get_output_shape(i); const Shape& shape = op->get_output_shape(i);
const element::Type& type = op->get_output_element_type(i); const element::Type& type = op->get_output_element_type(i);
string name = op->get_output_tensor(i).get_name(); string name = op->output(i).get_tensor().get_name();
host_tensor = make_shared<runtime::HostTensor>(type, shape, name); host_tensor = make_shared<runtime::HostTensor>(type, shape, name);
tensor_map.insert({tensor, host_tensor}); tensor_map.insert({tensor, host_tensor});
} }
...@@ -177,7 +196,7 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
} }
if (m_nan_check_enabled) if (m_nan_check_enabled)
{ {
perform_nan_check(op_outputs, op); perform_nan_check(op_outputs, op.get());
} }
} }
...@@ -186,19 +205,9 @@ bool runtime::gcpu::GCPUExecutable::call(const vector<shared_ptr<runtime::Tensor
void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type, void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
const NodeWrapper& op, const NodeWrapper& op,
const vector<shared_ptr<HostTensor>>& outputs, const vector<shared_ptr<HostTensor>>& out,
const vector<shared_ptr<HostTensor>>& inputs) const vector<shared_ptr<HostTensor>>& in)
{ {
vector<void*> out;
vector<const void*> in;
for (auto t : outputs)
{
out.push_back(t->get_data_ptr());
}
for (auto t : inputs)
{
in.push_back(t->get_data_ptr());
}
stringstream ss; stringstream ss;
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
...@@ -216,7 +225,8 @@ void runtime::gcpu::GCPUExecutable::generate_calls(const element::Type& type,
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::bf16: case element::Type_t::bf16:
ss << "unsupported element type " << type << " op " << op.get_node().get_name(); case element::Type_t::f16:
ss << "unsupported element type " << type << " op " << op.get_node()->get_name();
throw ngraph_error(ss.str()); throw ngraph_error(ss.str());
} }
} }
...@@ -229,11 +239,9 @@ void runtime::gcpu::GCPUExecutable::set_nan_check(bool enable)
vector<runtime::PerformanceCounter> runtime::gcpu::GCPUExecutable::get_performance_data() const vector<runtime::PerformanceCounter> runtime::gcpu::GCPUExecutable::get_performance_data() const
{ {
vector<runtime::PerformanceCounter> rc; vector<runtime::PerformanceCounter> rc;
for (const pair<const Node*, stopwatch> p : m_timer_map) for (const pair<shared_ptr<const Node>, stopwatch> p : m_timer_map)
{ {
rc.emplace_back(p.first->get_name().c_str(), rc.emplace_back(p.first, p.second.get_total_microseconds(), p.second.get_call_count());
p.second.get_total_microseconds(),
p.second.get_call_count());
} }
return rc; return rc;
} }
...@@ -286,3 +294,12 @@ void runtime::gcpu::GCPUExecutable::perform_nan_check(const vector<shared_ptr<Ho
arg_number++; arg_number++;
} }
} }
void runtime::gcpu::GCPUExecutable::save(ostream& out)
{
cpio::Writer writer(out);
string si = "INTERPRETER Save File 1.0";
writer.write("save_info", si.data(), si.size());
string model = serialize(m_function, 0);
writer.write("model", model.data(), model.size());
}
...@@ -17,24 +17,31 @@
#pragma once #pragma once
#include <initializer_list> #include <initializer_list>
#include <iostream>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "ngraph/op/all.hpp" #include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/any.hpp" #include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp" #include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp" #include "ngraph/op/argmin.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/batch_norm.hpp" #include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/broadcast_distributed.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp" #include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/dot.hpp" #include "ngraph/op/dot.hpp"
#include "ngraph/op/embedding_lookup.hpp" #include "ngraph/op/embedding_lookup.hpp"
#include "ngraph/op/experimental/batch_mat_mul.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_pad.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
...@@ -48,11 +55,14 @@
#include "ngraph/op/passthrough.hpp" #include "ngraph/op/passthrough.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/quantize.hpp" #include "ngraph/op/quantize.hpp"
#include "ngraph/op/quantized_convolution.hpp"
#include "ngraph/op/recv.hpp"
#include "ngraph/op/replace_slice.hpp" #include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/send.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/op/softmax.hpp" #include "ngraph/op/softmax.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -64,7 +74,6 @@
#include "ngraph/runtime/generic_cpu/kernel/reshape.hpp" #include "ngraph/runtime/generic_cpu/kernel/reshape.hpp"
#include "ngraph/runtime/generic_cpu/node_wrapper.hpp" #include "ngraph/runtime/generic_cpu/node_wrapper.hpp"
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/interpreter/node_wrapper.hpp"
#include "ngraph/runtime/reference/abs.hpp" #include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp" #include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/add.hpp" #include "ngraph/runtime/reference/add.hpp"
...@@ -77,7 +86,9 @@
#include "ngraph/runtime/reference/asin.hpp" #include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp" #include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp" #include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_mat_mul.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp" #include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/broadcast_distributed.hpp" #include "ngraph/runtime/reference/broadcast_distributed.hpp"
#include "ngraph/runtime/reference/ceiling.hpp" #include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp" #include "ngraph/runtime/reference/concat.hpp"
...@@ -89,8 +100,10 @@
#include "ngraph/runtime/reference/cosh.hpp" #include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/dequantize.hpp" #include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp" #include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/embedding_lookup.hpp" #include "ngraph/runtime/reference/embedding_lookup.hpp"
#include "ngraph/runtime/reference/equal.hpp" #include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp" #include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/floor.hpp" #include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp" #include "ngraph/runtime/reference/gather.hpp"
...@@ -117,14 +130,17 @@
#include "ngraph/runtime/reference/power.hpp" #include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/product.hpp" #include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp" #include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/recv.hpp"
#include "ngraph/runtime/reference/relu.hpp" #include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp" #include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp" #include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp" #include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/scatter_add.hpp" #include "ngraph/runtime/reference/scatter_add.hpp"
#include "ngraph/runtime/reference/scatter_nd_add.hpp" #include "ngraph/runtime/reference/scatter_nd_add.hpp"
#include "ngraph/runtime/reference/select.hpp" #include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/send.hpp"
#include "ngraph/runtime/reference/shape_of.hpp" #include "ngraph/runtime/reference/shape_of.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp" #include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
...@@ -134,6 +150,7 @@
#include "ngraph/runtime/reference/softmax.hpp" #include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp" #include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp" #include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp" #include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp" #include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp" #include "ngraph/runtime/reference/topk.hpp"
...@@ -154,6 +171,8 @@ namespace ngraph
class ngraph::runtime::gcpu::GCPUExecutable : public Executable class ngraph::runtime::gcpu::GCPUExecutable : public Executable
{ {
friend class GCPUBackend;
public: public:
GCPUExecutable(const std::shared_ptr<Function>& function, GCPUExecutable(const std::shared_ptr<Function>& function,
bool enable_performance_collection = false); bool enable_performance_collection = false);
...@@ -161,20 +180,25 @@ public:
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs, bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& intputs) override; const std::vector<std::shared_ptr<Tensor>>& intputs) override;
virtual void save(std::ostream& output_stream) override;
void set_nan_check(bool enable); void set_nan_check(bool enable);
std::vector<PerformanceCounter> get_performance_data() const override; std::vector<PerformanceCounter> get_performance_data() const override;
private: private:
GCPUExecutable(const std::string& model_string);
int get_alignment() const { return 64; }
bool m_is_compiled = false; bool m_is_compiled = false;
bool m_nan_check_enabled = false; bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false; bool m_performance_counters_enabled = false;
std::unordered_map<const Node*, stopwatch> m_timer_map; std::shared_ptr<Function> m_function;
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<NodeWrapper> m_wrapped_nodes; std::vector<NodeWrapper> m_wrapped_nodes;
std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states; std::unordered_map<const Node*, std::shared_ptr<RNGState>> m_states;
std::set<std::string> m_unsupported_op_name_list; std::set<std::string> m_unsupported_op_name_list;
int get_alignment() const { return 64; }
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&, static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr); const Node* op = nullptr);
...@@ -185,11 +209,10 @@ private:
template <typename T> template <typename T>
void op_engine(const NodeWrapper& node_wrapper, void op_engine(const NodeWrapper& node_wrapper,
const std::vector<void*>& out, const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<const void*>& args) const std::vector<std::shared_ptr<HostTensor>>& args)
{ {
const Node& node = node_wrapper.get_node(); const Node& node = *node_wrapper.get_node();
std::string node_op = node.description();
// We want to check that every OP_TYPEID enumeration is included in the list. // We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration // These GCC flags enable compile-time checking so that if an enumeration
...@@ -206,30 +229,30 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>( reference::abs<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Acos: case OP_TYPEID::Acos:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>( reference::acos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Add: case OP_TYPEID::Add:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::add<T>(static_cast<const T*>(args[0]), reference::add<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::All: case OP_TYPEID::All:
{ {
const op::All* all = static_cast<const op::All*>(&node); const op::All* all = static_cast<const op::All*>(&node);
reference::all(static_cast<const char*>(args[0]), reference::all(args[0]->get_data_ptr<const char>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
all->get_reduction_axes()); all->get_reduction_axes());
...@@ -237,26 +260,29 @@ private:
} }
case OP_TYPEID::AllReduce: case OP_TYPEID::AllReduce:
{ {
reference::allreduce<T>(static_cast<T*>(const_cast<void*>(args[0])), const ngraph::op::AllReduce* allreduce =
static_cast<T*>(out[0]), static_cast<const ngraph::op::AllReduce*>(&node);
node.get_input_element_type(0), reference::allreduce<T>(args[0]->get_data_ptr<T>(),
out[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
allreduce->get_reduce_type(),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))));
break; break;
} }
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_and(static_cast<const T*>(args[0]), reference::logical_and(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Any: case OP_TYPEID::Any:
{ {
const op::Any* any = static_cast<const op::Any*>(&node); const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(static_cast<const char*>(args[0]), reference::any(args[0]->get_data_ptr<const char>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
any->get_reduction_axes()); any->get_reduction_axes());
...@@ -268,16 +294,16 @@ private:
auto element_type = node.get_output_element_type(0); auto element_type = node.get_output_element_type(0);
if (element_type == element::i64) if (element_type == element::i64)
{ {
reference::argmin<T, int64_t>(static_cast<const T*>(args[0]), reference::argmin<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmin->get_reduction_axis()); argmin->get_reduction_axis());
} }
else if (element_type == element::i32) else if (element_type == element::i32)
{ {
reference::argmin<T, int32_t>(static_cast<const T*>(args[0]), reference::argmin<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmin->get_reduction_axis()); argmin->get_reduction_axis());
...@@ -294,16 +320,16 @@ private:
auto element_type = node.get_output_element_type(0); auto element_type = node.get_output_element_type(0);
if (element_type == element::i64) if (element_type == element::i64)
{ {
reference::argmax<T, int64_t>(static_cast<const T*>(args[0]), reference::argmax<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmax->get_reduction_axis()); argmax->get_reduction_axis());
} }
else if (element_type == element::i32) else if (element_type == element::i32)
{ {
reference::argmax<T, int32_t>(static_cast<const T*>(args[0]), reference::argmax<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
argmax->get_reduction_axis()); argmax->get_reduction_axis());
...@@ -318,22 +344,22 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>( reference::asin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Atan: case OP_TYPEID::Atan:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>( reference::atan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::AvgPool: case OP_TYPEID::AvgPool:
{ {
const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node); const op::AvgPool* avg_pool = static_cast<const op::AvgPool*>(&node);
reference::avg_pool<T>(static_cast<const T*>(args[0]), reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
avg_pool->get_window_shape(), avg_pool->get_window_shape(),
...@@ -345,18 +371,30 @@ private:
} }
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
bool use_seed = static_cast<bool>(args[2]->get_data_ptr<const int32_t>()[0]);
if (m_states.count(&node) == 0) if (m_states.count(&node) == 0)
{ {
const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node); const op::GenerateMask* gm = static_cast<const op::GenerateMask*>(&node);
auto seed = use_seed ? gm->get_seed() : 0;
m_states[&node] = std::unique_ptr<ngraph::RNGState>( m_states[&node] = std::unique_ptr<ngraph::RNGState>(
ngraph::RNGState::create_rng_state(gm->get_seed(), gm->get_probability())); ngraph::RNGState::create_rng_state(seed, gm->get_probability()));
} }
bool training = static_cast<bool>(static_cast<const T*>(args[0])[0]); bool training = static_cast<bool>(args[0]->get_data_ptr<const T>()[0]);
auto state = m_states.at(&node).get(); auto state = m_states.at(&node).get();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
if (!use_seed)
{
reference::generate_mask<T>( reference::generate_mask<T>(
reinterpret_cast<T*>(out[0]), element_count, state, training); out[0]->get_data_ptr<T>(), element_count, state, training);
}
else
{
uint64_t seed = static_cast<uint64_t>(args[3]->get_data_ptr<const T>()[0]);
double prob = static_cast<double>(args[4]->get_data_ptr<const T>()[0]);
reference::generate_mask_no_state<T>(
out[0]->get_data_ptr<T>(), element_count, training, seed, prob);
}
break; break;
} }
case OP_TYPEID::GetOutputElement: case OP_TYPEID::GetOutputElement:
...@@ -366,20 +404,31 @@ private:
size_t n = get_output_element->get_n(); size_t n = get_output_element->get_n();
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
size_t num_bytes = element_count * node.get_output_element_type(0).size(); size_t num_bytes = element_count * node.get_output_element_type(0).size();
std::memcpy(static_cast<T*>(out[0]), args[n], num_bytes); std::memcpy(out[0]->get_data_ptr<T>(), args[n]->get_data_ptr<T>(), num_bytes);
break;
}
case OP_TYPEID::BatchMatMul:
{
reference::batch_mat_mul(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
break; break;
} }
case OP_TYPEID::BatchNormTraining: case OP_TYPEID::BatchNormTraining:
{ {
const ngraph::op::BatchNormTraining* bn = const ngraph::op::BatchNormTraining* bn =
static_cast<const ngraph::op::BatchNormTraining*>(&node); static_cast<const ngraph::op::BatchNormTraining*>(&node);
reference::batch_norm_training<T>(bn->get_eps_value(), reference::batch_norm_training<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
static_cast<T*>(out[2]), out[2]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
...@@ -388,12 +437,12 @@ private:
const ngraph::op::BatchNormInference* bn = const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node); static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(), reference::batch_norm_inference<T>(bn->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<const T*>(args[3]), args[3]->get_data_ptr<const T>(),
static_cast<const T*>(args[4]), args[4]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
...@@ -402,23 +451,23 @@ private:
const ngraph::op::BatchNormTrainingBackprop* bn_bprop = const ngraph::op::BatchNormTrainingBackprop* bn_bprop =
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node); static_cast<const ngraph::op::BatchNormTrainingBackprop*>(&node);
reference::batch_norm_backprop(bn_bprop->get_eps_value(), reference::batch_norm_backprop(bn_bprop->get_eps_value(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<const T*>(args[3]), args[3]->get_data_ptr<const T>(),
static_cast<const T*>(args[4]), args[4]->get_data_ptr<const T>(),
static_cast<const T*>(args[5]), args[5]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
static_cast<T*>(out[2]), out[2]->get_data_ptr<T>(),
node.get_input_shape(2)); node.get_input_shape(2));
break; break;
} }
case OP_TYPEID::AvgPoolBackprop: case OP_TYPEID::AvgPoolBackprop:
{ {
const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node); const op::AvgPoolBackprop* apb = static_cast<const op::AvgPoolBackprop*>(&node);
reference::avg_pool_backprop<T>(static_cast<const T*>(args[0]), reference::avg_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
apb->get_window_shape(), apb->get_window_shape(),
...@@ -434,8 +483,8 @@ private:
Shape in_shape = node.get_input_shape(0); Shape in_shape = node.get_input_shape(0);
Shape out_shape = node.get_output_shape(0); Shape out_shape = node.get_output_shape(0);
AxisSet broadcast_axes = broadcast->get_broadcast_axes(); AxisSet broadcast_axes = broadcast->get_broadcast_axes();
gcpu::kernel::broadcast<T>(static_cast<const T*>(args[0]), kernel::broadcast<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
in_shape, in_shape,
out_shape, out_shape,
broadcast_axes); broadcast_axes);
...@@ -443,23 +492,28 @@ private:
} }
case OP_TYPEID::BroadcastDistributed: case OP_TYPEID::BroadcastDistributed:
{ {
int rank_ID = get_distributed_interface()->get_rank(); const ngraph::op::BroadcastDistributed* broadcast =
if (rank_ID == 0) static_cast<const ngraph::op::BroadcastDistributed*>(&node);
int rank_ID;
rank_ID = get_distributed_interface()->get_rank();
int root_id = broadcast->get_root_id();
if (rank_ID == root_id)
{ {
reference::broadcastdistributed<T>( reference::broadcastdistributed<T>(
static_cast<T*>(args[0]), args[0]->get_data_ptr<T>(),
node.get_input_element_type(0), node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))),
auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * root_id);
sizeof(node.get_input_element_type(0)); auto memSize = static_cast<int>(shape_size(node.get_input_shape(0))) * sizeof(T);
memcpy(out[0], args[0], memSize); memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
} }
else else
{ {
reference::broadcastdistributed<T>( reference::broadcastdistributed<T>(
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_element_type(0), node.get_input_element_type(0).get_type_enum(),
static_cast<int>(shape_size(node.get_input_shape(0)))); static_cast<int>(shape_size(node.get_input_shape(0))),
root_id);
} }
break; break;
} }
...@@ -468,7 +522,7 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>( reference::ceiling<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Concat: case OP_TYPEID::Concat:
...@@ -478,11 +532,11 @@ private:
std::vector<Shape> in_shapes; std::vector<Shape> in_shapes;
for (size_t i = 0; i < node.get_input_size(); i++) for (size_t i = 0; i < node.get_input_size(); i++)
{ {
in_args.push_back(static_cast<const T*>(args[i])); in_args.push_back(args[i]->get_data_ptr<const T>());
in_shapes.push_back(node.get_input_shape(i)); in_shapes.push_back(node.get_input_shape(i));
} }
reference::concat<T>(in_args, reference::concat<T>(in_args,
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
in_shapes, in_shapes,
node.get_output_shape(0), node.get_output_shape(0),
concat->get_concatenation_axis()); concat->get_concatenation_axis());
...@@ -492,7 +546,7 @@ private:
{ {
const op::Constant* c = static_cast<const op::Constant*>(&node); const op::Constant* c = static_cast<const op::Constant*>(&node);
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::constant<T>(c->get_data_ptr<T>(), static_cast<T*>(out[0]), element_count); reference::constant<T>(c->get_data_ptr<T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ScalarConstantLike: break; case OP_TYPEID::ScalarConstantLike: break;
...@@ -505,52 +559,62 @@ private:
switch (type.get_type_enum()) switch (type.get_type_enum())
{ {
case element::Type_t::boolean: case element::Type_t::boolean:
reference::convert<T>( reference::convert_to_bool<T>(
static_cast<const T*>(args[0]), static_cast<char*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
break; break;
case element::Type_t::f32: case element::Type_t::f32:
reference::convert<T>( reference::convert<T>(
static_cast<const T*>(args[0]), static_cast<float*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
break; break;
case element::Type_t::f64: case element::Type_t::f64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<double*>(out[0]), element_count); out[0]->get_data_ptr<double>(),
element_count);
break; break;
case element::Type_t::i8: case element::Type_t::i8:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int8_t*>(out[0]), element_count); out[0]->get_data_ptr<int8_t>(),
element_count);
break; break;
case element::Type_t::i16: case element::Type_t::i16:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int16_t*>(out[0]), element_count); out[0]->get_data_ptr<int16_t>(),
element_count);
break; break;
case element::Type_t::i32: case element::Type_t::i32:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int32_t*>(out[0]), element_count); out[0]->get_data_ptr<int32_t>(),
element_count);
break; break;
case element::Type_t::i64: case element::Type_t::i64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<int64_t*>(out[0]), element_count); out[0]->get_data_ptr<int64_t>(),
element_count);
break; break;
case element::Type_t::u8: case element::Type_t::u8:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint8_t*>(out[0]), element_count); out[0]->get_data_ptr<uint8_t>(),
element_count);
break; break;
case element::Type_t::u16: case element::Type_t::u16:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint16_t*>(out[0]), element_count); out[0]->get_data_ptr<uint16_t>(),
element_count);
break; break;
case element::Type_t::u32: case element::Type_t::u32:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint32_t*>(out[0]), element_count); out[0]->get_data_ptr<uint32_t>(),
element_count);
break; break;
case element::Type_t::u64: case element::Type_t::u64:
reference::convert<T>( reference::convert<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), static_cast<uint64_t*>(out[0]), element_count); out[0]->get_data_ptr<uint64_t>(),
element_count);
break; break;
case element::Type_t::undefined: case element::Type_t::undefined:
case element::Type_t::dynamic: case element::Type_t::dynamic:
case element::Type_t::bf16: case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert"; ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str()); throw std::runtime_error(ss.str());
} }
...@@ -559,9 +623,9 @@ private:
case OP_TYPEID::Convolution: case OP_TYPEID::Convolution:
{ {
const op::Convolution* c = static_cast<const op::Convolution*>(&node); const op::Convolution* c = static_cast<const op::Convolution*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]), reference::convolution<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
...@@ -569,38 +633,26 @@ private:
c->get_window_dilation_strides(), c->get_window_dilation_strides(),
c->get_padding_below(), c->get_padding_below(),
c->get_padding_above(), c->get_padding_above(),
c->get_data_dilation_strides(), c->get_data_dilation_strides());
0,
1,
1,
0,
0,
1,
false);
break; break;
} }
case OP_TYPEID::ConvolutionBackpropFilters: case OP_TYPEID::ConvolutionBackpropFilters:
{ {
const op::ConvolutionBackpropFilters* c = const op::ConvolutionBackpropFilters* c =
static_cast<const op::ConvolutionBackpropFilters*>(&node); static_cast<const op::ConvolutionBackpropFilters*>(&node);
reference::convolution<T>(static_cast<const T*>(args[0]), reference::convolution_backprop_filter<T>(
static_cast<const T*>(args[1]), args[0]->get_data_ptr<const T>(), // input
static_cast<T*>(out[0]), args[1]->get_data_ptr<const T>(), // delta_convolution_output
node.get_input_shape(0), out[0]->get_data_ptr<T>(), // delta_filter
node.get_input_shape(1), c->get_input_shape(0), // input_shape
node.get_output_shape(0), c->get_input_shape(1), // convolution_output_shape
c->get_window_movement_strides_backward(), c->get_filters_shape(), // filter_shape
c->get_window_dilation_strides_backward(), c->get_window_dilation_strides_forward(),
c->get_padding_below_backward(), c->get_window_movement_strides_forward(),
c->get_padding_above_backward(), c->get_padding_below_forward(),
c->get_data_dilation_strides_backward(), c->compute_backward_in_pad_above(),
1, c->get_data_dilation_strides_forward());
0,
0,
1,
1,
0,
false);
break; break;
} }
case OP_TYPEID::ConvolutionBackpropData: case OP_TYPEID::ConvolutionBackpropData:
...@@ -608,38 +660,31 @@ private: ...@@ -608,38 +660,31 @@ private:
// Note that args[1] and args[0] are switched here from the usual order. // Note that args[1] and args[0] are switched here from the usual order.
const op::ConvolutionBackpropData* c = const op::ConvolutionBackpropData* c =
static_cast<const op::ConvolutionBackpropData*>(&node); static_cast<const op::ConvolutionBackpropData*>(&node);
reference::convolution<T>(static_cast<const T*>(args[1]), reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[0]), args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), c->get_input_shape(1),
node.get_input_shape(0), c->get_input_shape(0),
node.get_output_shape(0), c->get_data_batch_shape(),
c->get_window_movement_strides_backward(), c->get_data_dilation_strides_forward(),
c->get_window_dilation_strides_backward(), c->get_window_dilation_strides_forward(),
c->get_padding_below_backward(), c->compute_backward_delta_out_pad_below(),
c->get_padding_above_backward(), c->compute_backward_delta_out_pad_above(),
c->get_data_dilation_strides_backward(), c->get_window_movement_strides_forward());
0,
1,
0,
1,
0,
1,
true);
break; break;
} }
case OP_TYPEID::Cos: case OP_TYPEID::Cos:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>( reference::cos<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Cosh: case OP_TYPEID::Cosh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>( reference::cosh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Dequantize: case OP_TYPEID::Dequantize:
...@@ -649,20 +694,20 @@ private: ...@@ -649,20 +694,20 @@ private:
if (type == element::f32) if (type == element::f32)
{ {
reference::dequantize<T>(static_cast<const T*>(args[0]), reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const float*>(args[1]), args[1]->get_data_ptr<const float>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<float*>(out[0]), out[0]->get_data_ptr<float>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
dequantize->get_axes()); dequantize->get_axes());
} }
else if (type == element::f64) else if (type == element::f64)
{ {
reference::dequantize<T>(static_cast<const T*>(args[0]), reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const double*>(args[1]), args[1]->get_data_ptr<const double>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<double*>(out[0]), out[0]->get_data_ptr<double>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
dequantize->get_axes()); dequantize->get_axes());
...@@ -680,9 +725,9 @@ private: ...@@ -680,9 +725,9 @@ private:
{ {
const op::Divide* divop = static_cast<const op::Divide*>(&node); const op::Divide* divop = static_cast<const op::Divide*>(&node);
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::divide<T>(static_cast<const T*>(args[0]), reference::divide<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
divop->is_pythondiv()); divop->is_pythondiv());
break; break;
...@@ -691,15 +736,25 @@ private: ...@@ -691,15 +736,25 @@ private:
{ {
const op::Dot* dot = static_cast<const op::Dot*>(&node); const op::Dot* dot = static_cast<const op::Dot*>(&node);
gcpu::kernel::dot(static_cast<const T*>(args[0]), kernel::dot(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
dot->get_reduction_axes_count()); dot->get_reduction_axes_count());
break; break;
} }
case OP_TYPEID::DynReshape:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::DynSlice:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
}
case OP_TYPEID::EmbeddingLookup: case OP_TYPEID::EmbeddingLookup:
{ {
const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node); const op::EmbeddingLookup* embed = static_cast<const op::EmbeddingLookup*>(&node);
...@@ -708,33 +763,33 @@ private: ...@@ -708,33 +763,33 @@ private:
if (type == element::f32) if (type == element::f32)
{ {
reference::embedding<T, float>(static_cast<const float*>(args[0]), reference::embedding<T, float>(args[0]->get_data_ptr<const float>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::f64) else if (type == element::f64)
{ {
reference::embedding<T, double>(static_cast<const double*>(args[0]), reference::embedding<T, double>(args[0]->get_data_ptr<const double>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::embedding<T, int>(static_cast<const int*>(args[0]), reference::embedding<T, int32_t>(args[0]->get_data_ptr<const int>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
else if (type == element::i64) else if (type == element::i64)
{ {
reference::embedding<T, int64_t>(static_cast<const int64_t*>(args[0]), reference::embedding<T, int64_t>(args[0]->get_data_ptr<const int64_t>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count, element_count,
embed->get_shape()); embed->get_shape());
} }
...@@ -748,24 +803,56 @@ private: ...@@ -748,24 +803,56 @@ private:
case OP_TYPEID::Equal: case OP_TYPEID::Equal:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::equal<T>(static_cast<const T*>(args[0]), reference::equal<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Erf:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::erf<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Exp: case OP_TYPEID::Exp:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>( reference::exp<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
#ifdef INTERPRETER_USE_HYBRID
case OP_TYPEID::FunctionCall:
{
auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
auto backend = f->get_backend();
auto executable = f->get_executable();
std::vector<std::shared_ptr<Tensor>> outputs;
std::vector<std::shared_ptr<Tensor>> inputs;
for (const std::shared_ptr<HostTensor>& t : out)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
outputs.push_back(backend_tensor);
}
for (const std::shared_ptr<HostTensor>& t : args)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
inputs.push_back(backend_tensor);
}
executable->call(outputs, inputs);
break; break;
} }
#endif
case OP_TYPEID::Floor: case OP_TYPEID::Floor:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>( reference::floor<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
...@@ -826,36 +913,36 @@ private: ...@@ -826,36 +913,36 @@ private:
case OP_TYPEID::Greater: case OP_TYPEID::Greater:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::greater<T>(static_cast<const T*>(args[0]), reference::greater<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::GreaterEq: case OP_TYPEID::GreaterEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::greater_eq<T>(static_cast<const T*>(args[0]), reference::greater_eq<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Less: case OP_TYPEID::Less:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::less<T>(static_cast<const T*>(args[0]), reference::less<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::less_eq<T>(static_cast<const T*>(args[0]), reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
...@@ -863,14 +950,14 @@ private: ...@@ -863,14 +950,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>( reference::log<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::LRN: case OP_TYPEID::LRN:
{ {
const op::LRN* lrn = static_cast<const op::LRN*>(&node); const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(static_cast<const T*>(args[0]), reference::lrn<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
lrn->get_alpha(), lrn->get_alpha(),
lrn->get_beta(), lrn->get_beta(),
...@@ -881,8 +968,8 @@ private: ...@@ -881,8 +968,8 @@ private:
case OP_TYPEID::Max: case OP_TYPEID::Max:
{ {
const op::Max* max = static_cast<const op::Max*>(&node); const op::Max* max = static_cast<const op::Max*>(&node);
reference::max<T>(static_cast<const T*>(args[0]), reference::max<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max->get_reduction_axes()); max->get_reduction_axes());
...@@ -891,9 +978,9 @@ private: ...@@ -891,9 +978,9 @@ private:
case OP_TYPEID::Maximum: case OP_TYPEID::Maximum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::maximum<T>(static_cast<const T*>(args[0]), reference::maximum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -901,8 +988,8 @@ private: ...@@ -901,8 +988,8 @@ private:
{ {
const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node); const op::MaxPool* max_pool = static_cast<const op::MaxPool*>(&node);
reference::max_pool<T>(static_cast<const T*>(args[0]), reference::max_pool<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
max_pool->get_window_shape(), max_pool->get_window_shape(),
...@@ -916,9 +1003,9 @@ private: ...@@ -916,9 +1003,9 @@ private:
const op::MaxPoolBackprop* max_pool_backprop = const op::MaxPoolBackprop* max_pool_backprop =
static_cast<const op::MaxPoolBackprop*>(&node); static_cast<const op::MaxPoolBackprop*>(&node);
reference::max_pool_backprop<T>(static_cast<const T*>(args[0]), reference::max_pool_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
node.get_output_shape(0), node.get_output_shape(0),
max_pool_backprop->get_window_shape(), max_pool_backprop->get_window_shape(),
...@@ -930,8 +1017,8 @@ private: ...@@ -930,8 +1017,8 @@ private:
case OP_TYPEID::Min: case OP_TYPEID::Min:
{ {
const op::Min* min = static_cast<const op::Min*>(&node); const op::Min* min = static_cast<const op::Min*>(&node);
reference::min<T>(static_cast<const T*>(args[0]), reference::min<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
min->get_reduction_axes()); min->get_reduction_axes());
...@@ -940,18 +1027,18 @@ private: ...@@ -940,18 +1027,18 @@ private:
case OP_TYPEID::Minimum: case OP_TYPEID::Minimum:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::minimum<T>(static_cast<const T*>(args[0]), reference::minimum<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Multiply: case OP_TYPEID::Multiply:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::multiply<T>(static_cast<const T*>(args[0]), reference::multiply<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -959,30 +1046,30 @@ private: ...@@ -959,30 +1046,30 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>( reference::negate<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Not: case OP_TYPEID::Not:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not( reference::logical_not(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::NotEqual: case OP_TYPEID::NotEqual:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::not_equal<T>(static_cast<const T*>(args[0]), reference::not_equal<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<char*>(out[0]), out[0]->get_data_ptr<char>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::OneHot: case OP_TYPEID::OneHot:
{ {
const op::OneHot* oh = static_cast<const op::OneHot*>(&node); const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(static_cast<const T*>(args[0]), reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
oh->get_one_hot_axis()); oh->get_one_hot_axis());
...@@ -991,46 +1078,46 @@ private: ...@@ -991,46 +1078,46 @@ private:
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_or(static_cast<const T*>(args[0]), reference::logical_or(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Parameter: break;
case OP_TYPEID::Pad:
{
const op::Pad* pad = static_cast<const op::Pad*>(&node);
reference::pad(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.input(0).get_shape(),
node.output(0).get_shape(),
pad->get_padding_below(),
pad->get_padding_above(),
pad->get_pad_mode());
break;
}
case OP_TYPEID::Passthrough:
{
const op::Passthrough* passthrough = static_cast<const op::Passthrough*>(&node);
throw unsupported_op{"Unsupported operation language: " + passthrough->language()};
}
case OP_TYPEID::Power: case OP_TYPEID::Power:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::power<T>(static_cast<const T*>(args[0]), reference::power<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
const op::Product* product = static_cast<const op::Product*>(&node); const op::Product* product = static_cast<const op::Product*>(&node);
reference::product<T>(static_cast<const T*>(args[0]), reference::product<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
product->get_reduction_axes()); product->get_reduction_axes());
...@@ -1043,10 +1130,10 @@ private: ...@@ -1043,10 +1130,10 @@ private:
if (type == element::u8) if (type == element::u8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const uint8_t*>(args[2]), args[2]->get_data_ptr<const uint8_t>(),
static_cast<uint8_t*>(out[0]), out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1054,10 +1141,10 @@ private: ...@@ -1054,10 +1141,10 @@ private:
} }
else if (type == element::i8) else if (type == element::i8)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int8_t*>(args[2]), args[2]->get_data_ptr<const int8_t>(),
static_cast<int8_t*>(out[0]), out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1065,10 +1152,10 @@ private: ...@@ -1065,10 +1152,10 @@ private:
} }
else if (type == element::i32) else if (type == element::i32)
{ {
reference::quantize<T>(static_cast<const T*>(args[0]), reference::quantize<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const int32_t*>(args[2]), args[2]->get_data_ptr<const int32_t>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_input_shape(1), node.get_input_shape(1),
quantize->get_axes(), quantize->get_axes(),
...@@ -1083,40 +1170,168 @@ private: ...@@ -1083,40 +1170,168 @@ private:
break; break;
} }
case OP_TYPEID::QuantizedConvolution:
{
const op::QuantizedConvolution* qc =
static_cast<const op::QuantizedConvolution*>(&node);
auto input_element_type = qc->get_input_element_type(0);
auto filter_element_type = qc->get_input_element_type(1);
auto output_element_type = qc->get_output_element_type(0);
if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i8)
{
reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::u8)
{
reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const uint8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else
{
std::stringstream ss;
ss << "unsupported element type";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::QuantizedAvgPool:
case OP_TYPEID::QuantizedConvolutionBias:
case OP_TYPEID::QuantizedConvolutionBiasAdd:
case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
case OP_TYPEID::QuantizedConvolutionRelu:
case OP_TYPEID::QuantizedMaxPool:
case OP_TYPEID::QuantizedDotBias:
case OP_TYPEID::QuantizedDot:
{
throw unsupported_op("Unsupported op '" + node.description() +
"' in Interpreter back end.");
}
case OP_TYPEID::Recv:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Recv*>(&node);
int src_id = op->get_src_id();
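// reference::recv fills args[0]'s buffer in place; the memcpy below then copies
// that buffer through to the output tensor.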
reference::recv<T>(args[0]->get_data_ptr<T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
src_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::Range:
{
throw unsupported_op("Unsupported op '" + node.description() + "'");
break;
} }
case OP_TYPEID::Relu: case OP_TYPEID::Relu:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>( reference::relu<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::ReluBackprop: case OP_TYPEID::ReluBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::relu_backprop<T>(static_cast<const T*>(args[0]), reference::relu_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::ReplaceSlice: case OP_TYPEID::ReplaceSlice:
{ {
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node); const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(static_cast<const T*>(args[0]), reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(1), node.get_input_shape(1),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1127,8 +1342,8 @@ private: ...@@ -1127,8 +1342,8 @@ private:
case OP_TYPEID::Reshape: case OP_TYPEID::Reshape:
{ {
const op::Reshape* reshape = static_cast<const op::Reshape*>(&node); const op::Reshape* reshape = static_cast<const op::Reshape*>(&node);
gcpu::kernel::reshape(static_cast<const T*>(args[0]), kernel::reshape(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reshape->get_input_order(), reshape->get_input_order(),
node.get_output_shape(0)); node.get_output_shape(0));
...@@ -1137,16 +1352,16 @@ private: ...@@ -1137,16 +1352,16 @@ private:
case OP_TYPEID::Result: case OP_TYPEID::Result:
{ {
const op::Result* res = static_cast<const op::Result*>(&node); const op::Result* res = static_cast<const op::Result*>(&node);
reference::result(static_cast<const T*>(args[0]), reference::result(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
shape_size(res->get_shape())); shape_size(res->get_shape()));
break; break;
} }
case OP_TYPEID::Reverse: case OP_TYPEID::Reverse:
{ {
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node); const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(static_cast<const T*>(args[0]), reference::reverse(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
reverse->get_reversed_axes()); reverse->get_reversed_axes());
...@@ -1158,12 +1373,12 @@ private: ...@@ -1158,12 +1373,12 @@ private:
if (node.get_input_element_type(1) == element::i32) if (node.get_input_element_type(1) == element::i32)
{ {
reference::reverse_sequence<T, int32_t>(static_cast<const T*>(args[0]), reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
reverse->get_batch_axis(), reverse->get_batch_axis(),
reverse->get_sequence_axis(), reverse->get_sequence_axis(),
static_cast<const int32_t*>(args[1])); args[1]->get_data_ptr<const int32_t>());
} }
else else
{ {
...@@ -1234,31 +1449,46 @@ private: ...@@ -1234,31 +1449,46 @@ private:
case OP_TYPEID::Select: case OP_TYPEID::Select:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(static_cast<const char*>(args[0]), reference::select<T>(args[0]->get_data_ptr<const char>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<const T*>(args[2]), args[2]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Send:
{
size_t element_count = shape_size(node.get_output_shape(0));
size_t memSize = element_count * sizeof(T);
const auto* op = static_cast<const ngraph::op::Send*>(&node);
int dest_id = op->get_dest_id();
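// reference::send transmits args[0]'s buffer to dest_id; the input is then copied
// through to the output tensor (memcpy below), so the op acts as a local pass-through.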
reference::send<T>(args[0]->get_data_ptr<const T>(),
node.get_input_element_type(0).get_type_enum(),
element_count,
dest_id);
memcpy(out[0]->get_data_ptr<T>(), args[0]->get_data_ptr<T>(), memSize);
break;
}
case OP_TYPEID::ShapeOf: case OP_TYPEID::ShapeOf:
{ {
reference::shape_of(node.get_input_shape(0), static_cast<uint64_t*>(out[0])); reference::shape_of(node.get_input_shape(0), out[0]->get_data_ptr<uint64_t>());
break; break;
} }
case OP_TYPEID::Sigmoid: case OP_TYPEID::Sigmoid:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>( reference::sigmoid<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::SigmoidBackprop: case OP_TYPEID::SigmoidBackprop:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid_backprop<T>(static_cast<const T*>(args[0]), reference::sigmoid_backprop<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
...@@ -1266,28 +1496,28 @@ private: ...@@ -1266,28 +1496,28 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>( reference::sign<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sin: case OP_TYPEID::Sin:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>( reference::sin<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Sinh: case OP_TYPEID::Sinh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>( reference::sinh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Slice: case OP_TYPEID::Slice:
{ {
const op::Slice* slice = static_cast<const op::Slice*>(&node); const op::Slice* slice = static_cast<const op::Slice*>(&node);
reference::slice<T>(static_cast<const T*>(args[0]), reference::slice<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
slice->get_lower_bounds(), slice->get_lower_bounds(),
slice->get_upper_bounds(), slice->get_upper_bounds(),
...@@ -1298,8 +1528,8 @@ private: ...@@ -1298,8 +1528,8 @@ private:
case OP_TYPEID::Softmax: case OP_TYPEID::Softmax:
{ {
const op::Softmax* softmax = static_cast<const op::Softmax*>(&node); const op::Softmax* softmax = static_cast<const op::Softmax*>(&node);
reference::softmax<T>(static_cast<const T*>(args[0]), reference::softmax<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_output_shape(0), node.get_output_shape(0),
softmax->get_axes()); softmax->get_axes());
break; break;
...@@ -1308,7 +1538,7 @@ private: ...@@ -1308,7 +1538,7 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>( reference::sqrt<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'"); case OP_TYPEID::StopGradient: { throw unsupported_op("Unsupported op 'StopGradient'");
...@@ -1316,17 +1546,17 @@ private: ...@@ -1316,17 +1546,17 @@ private:
case OP_TYPEID::Subtract: case OP_TYPEID::Subtract:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::subtract<T>(static_cast<const T*>(args[0]), reference::subtract<T>(args[0]->get_data_ptr<const T>(),
static_cast<const T*>(args[1]), args[1]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
element_count); element_count);
break; break;
} }
case OP_TYPEID::Sum: case OP_TYPEID::Sum:
{ {
const op::Sum* sum = static_cast<const op::Sum*>(&node); const op::Sum* sum = static_cast<const op::Sum*>(&node);
reference::sum<T>(static_cast<const T*>(args[0]), reference::sum<T>(args[0]->get_data_ptr<const T>(),
static_cast<T*>(out[0]), out[0]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
sum->get_reduction_axes()); sum->get_reduction_axes());
...@@ -1336,14 +1566,14 @@ private: ...@@ -1336,14 +1566,14 @@ private:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>( reference::tan<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::Tanh: case OP_TYPEID::Tanh:
{ {
size_t element_count = shape_size(node.get_output_shape(0)); size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>( reference::tanh<T>(
static_cast<const T*>(args[0]), static_cast<T*>(out[0]), element_count); args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break; break;
} }
case OP_TYPEID::TopK: case OP_TYPEID::TopK:
...@@ -1351,9 +1581,9 @@ private: ...@@ -1351,9 +1581,9 @@ private:
const op::TopK* topk = static_cast<const op::TopK*>(&node); const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64) if (node.get_output_element_type(0) == element::i64)
{ {
reference::topk<T, int64_t>(static_cast<const T*>(args[0]), reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
static_cast<int64_t*>(out[0]), out[0]->get_data_ptr<int64_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
...@@ -1362,9 +1592,9 @@ private: ...@@ -1362,9 +1592,9 @@ private:
} }
else if (node.get_output_element_type(0) == element::i32) else if (node.get_output_element_type(0) == element::i32)
{ {
reference::topk<T, int32_t>(static_cast<const T*>(args[0]), reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
static_cast<int32_t*>(out[0]), out[0]->get_data_ptr<int32_t>(),
static_cast<T*>(out[1]), out[1]->get_data_ptr<T>(),
node.get_input_shape(0), node.get_input_shape(0),
node.get_output_shape(0), node.get_output_shape(0),
topk->get_top_k_axis(), topk->get_top_k_axis(),
...@@ -1377,7 +1607,12 @@ private: ...@@ -1377,7 +1607,12 @@ private:
}
break;
}
case OP_TYPEID::DynBroadcast:
case OP_TYPEID::Transpose:
case OP_TYPEID::DynPad:
case OP_TYPEID::Tile:
case OP_TYPEID::DynReplaceSlice:
throw unsupported_op("Unsupported op '" + node.description() + "'");
#if !(defined(__GNUC__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 8))
#pragma GCC diagnostic pop
#endif
......
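The recurring change throughout the executable diff above is mechanical: manual casts of raw buffers are replaced by the typed HostTensor accessor (args and out hold std::shared_ptr<HostTensor>). Schematically, as an illustrative fragment rather than part of the diff:
// Before: arguments arrived as raw pointers and were cast by hand.
//     const T* src = static_cast<const T*>(args[0]);
// After: the HostTensor exposes typed access to the same storage.
//     const T* src = args[0]->get_data_ptr<const T>();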
...@@ -140,6 +140,91 @@ namespace ngraph
}
}
template <typename T>
void broadcast_5d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[5];
size_t* out_index = 0;
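// Note: this kernel handles the case of a single non-broadcast output axis; the input
// is read at that axis's running coordinate (in_shape itself is not consulted).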
for (size_t i = 0; i < 5; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
out[index[0] * out_shape[1] * out_shape[2] * out_shape[3] *
out_shape[4] +
index[1] * out_shape[2] * out_shape[3] * out_shape[4] +
index[2] * out_shape[3] * out_shape[4] +
index[3] * out_shape[4] + index[4]] = in[*out_index];
}
}
}
}
}
}
template <typename T>
void broadcast_6d(const T* in,
T* out,
const Shape& in_shape,
const Shape& out_shape,
const AxisSet& broadcast_axes)
{
size_t index[6];
size_t* out_index = 0;
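// Same scheme as broadcast_5d: locate the single non-broadcast axis used to index the input.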
for (size_t i = 0; i < 6; i++)
{
if (broadcast_axes.count(i) == 0)
{
out_index = &index[i];
break;
}
}
for (index[0] = 0; index[0] < out_shape[0]; ++index[0])
{
for (index[1] = 0; index[1] < out_shape[1]; ++index[1])
{
for (index[2] = 0; index[2] < out_shape[2]; ++index[2])
{
for (index[3] = 0; index[3] < out_shape[3]; ++index[3])
{
for (index[4] = 0; index[4] < out_shape[4]; ++index[4])
{
for (index[5] = 0; index[5] < out_shape[5]; ++index[5])
{
out[index[0] * out_shape[1] * out_shape[2] *
out_shape[3] * out_shape[4] * out_shape[5] +
index[1] * out_shape[2] * out_shape[3] *
out_shape[4] * out_shape[5] +
index[2] * out_shape[3] * out_shape[4] *
out_shape[5] +
index[3] * out_shape[4] * out_shape[5] +
index[4] * out_shape[5] + index[5]] =
in[*out_index];
}
}
}
}
}
}
}
template <typename T>
void broadcast(const T* in,
T* out,
...@@ -167,6 +252,16 @@ namespace ngraph
case 4:
broadcast_4d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 5:
broadcast_5d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
case 6:
broadcast_6d<T>(in, out, in_shape, out_shape, broadcast_axes);
break;
default:
runtime::reference::broadcast<T>(
in, out, in_shape, out_shape, broadcast_axes);
break;
}
}
else
......
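For reference, a small worked example of the row-major offset arithmetic used by broadcast_5d and broadcast_6d above; the shape and index values are illustrative only.
// Illustrative only: flat offset of index {1, 2, 3, 4, 5} in out_shape {2, 3, 4, 5, 6},
// following the same expression as broadcast_5d.
size_t offset = 1 * (3 * 4 * 5 * 6) // 360
              + 2 * (4 * 5 * 6)     // 240
              + 3 * (5 * 6)         // 90
              + 4 * 6               // 24
              + 5;                  // offset == 719, the last of the 720 output elements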
...@@ -244,10 +244,7 @@ namespace ngraph
case 4: reshape_in4<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 5: reshape_in5<T>(in, out, in_shape, in_axis_order, out_shape); break;
case 6: reshape_in6<T>(in, out, in_shape, in_axis_order, out_shape); break;
default: reference::reshape(in, out, in_shape, in_axis_order, out_shape); break;
}
}
}
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <algorithm>
#include <cmath>
#include <cstring> // memcpy
#include <numeric>
#include <vector>
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace gcpu
{
namespace kernel
{
template <typename T>
void result(const T* arg, T* out, size_t count)
{
memcpy(out, arg, sizeof(T) * count);
}
}
}
}
}
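A minimal, stand-alone sketch of how the new gcpu result kernel above could be exercised; the include path and values are assumptions for illustration only.
// Hypothetical check of kernel::result (illustrative; header path assumed).
#include "result.hpp"
int main()
{
    float src[3] = {1.0f, 2.0f, 3.0f};
    float dst[3] = {0.0f, 0.0f, 0.0f};
    ngraph::runtime::gcpu::kernel::result(src, dst, 3); // copies three floats
    return dst[2] == 3.0f ? 0 : 1;                      // 0 on success
}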
...@@ -51,7 +51,7 @@ class ngraph::runtime::gcpu::NodeWrapper
public:
NodeWrapper(const std::shared_ptr<const ngraph::Node>& node);
std::shared_ptr<const Node> get_node() const { return m_node; }
ngraph::runtime::gcpu::OP_TYPEID get_typeid() const { return m_typeid; }
private:
std::shared_ptr<const ngraph::Node> m_node;
......