Commit 6ca1a511 authored by amy.zhuang

Merge branch 'master' into ayzhuang/in-place-concat

parents 749a1c6b 9debc0bc
@@ -68,7 +68,7 @@ function build_ngraph() {
    make install || return 1
    cd "${ngraph_directory}/ngraph/python"
    if [ ! -d ./pybind11 ]; then
-       git clone --recursive -b allow-nonconstructible-holders https://github.com/jagerman/pybind11.git
+       git clone --recursive https://github.com/pybind/pybind11.git
    fi
    export PYBIND_HEADERS_PATH="${ngraph_directory}/ngraph/python/pybind11"
    export NGRAPH_CPP_BUILD_PATH="${ngraph_directory}/ngraph_dist"
...
@@ -39,7 +39,7 @@ RUN make install
# Prepare nGraph Python API
WORKDIR /root/ngraph/python
-RUN git clone --recursive -b allow-nonconstructible-holders https://github.com/jagerman/pybind11.git
+RUN git clone --recursive https://github.com/pybind/pybind11.git
ENV NGRAPH_CPP_BUILD_PATH /root/ngraph_dist
ENV LD_LIBRARY_PATH /root/ngraph_dist/lib
ENV PYBIND_HEADERS_PATH /root/ngraph/python/pybind11
...
@@ -196,25 +196,8 @@ if(WIN32)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_CRT_SECURE_NO_WARNINGS")
endif()

-include(unit_test_control)
-set(UNIT_TEST_CONFIG_LIST "" CACHE INTERNAL "")
-
-if (NGRAPH_INTERPRETER_ENABLE)
-    unit_test_control(BACKEND INTERPRETER MANIFEST src/ngraph/runtime/interpreter/unit_test.manifest)
-endif()
-
+# Set true if CPU backend is built by default
if (NGRAPH_CPU_ENABLE)
-    unit_test_control(BACKEND CPU MANIFEST src/ngraph/runtime/cpu/unit_test.manifest)
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNGRAPH_CPU_ENABLE")
endif()

-if (NGRAPH_INTELGPU_ENABLE)
-    unit_test_control(BACKEND INTELGPU MANIFEST src/ngraph/runtime/intelgpu/unit_test.manifest)
-endif()
-
-if (NGRAPH_GPU_ENABLE)
-    unit_test_control(BACKEND GPU MANIFEST src/ngraph/runtime/gpu/unit_test.manifest)
-endif()
-
if (NOT DEFINED NGRAPH_TBB_ENABLE)
...
# nGraph Library [![Build Status][build-status-badge]][build-status]

-Welcome to the open-source repository for the Intel® nGraph™ Library. Our code
+Welcome to the open-source repository for the **Intel® nGraph Library**. Our code
base provides a Compiler and runtime suite of tools (APIs) designed to give
developers maximum flexibility for their software design, allowing them to
create or customize a scalable solution using any framework while also avoiding
@@ -11,7 +10,9 @@ backends, and it will be able to run on any backends we support in the future
with minimal disruption to your model. With nGraph, you can co-evolve your
software and hardware's capabilities to stay at the forefront of your industry.

-The nGraph Compiler is Intel's graph compiler for Artificial Neural Networks.
+![nGraph ecosystem][ngraph-ecosystem]
+
+The **nGraph Compiler** is Intel's graph compiler for Artificial Neural Networks.
Documentation in this repo describes how you can program any framework
to run training and inference computations on a variety of Backends including
Intel® Architecture Processors (CPUs), Intel® Nervana™ Neural Network Processors
@@ -24,18 +25,29 @@ whatever scenario you need.

nGraph provides both a C++ API for framework developers and a Python API which
can run inference on models imported from ONNX.

-![nGraph ecosystem][ngraph-ecosystem]
+See the [Release Notes] for recent changes.

-| Framework   | bridge available? | ONNX support? |
-|-------------|-------------------|---------------|
-| neon        | yes               | yes           |
-| MXNet*      | yes               | yes           |
-| TensorFlow* | yes               | yes           |
-| PyTorch*    | not yet           | yes           |
-| Chainer*    | not yet           | yes           |
-| CNTK*       | not yet           | yes           |
-| Caffe2*     | not yet           | yes           |
+| Framework      | bridge available? | ONNX support? |
+|----------------|-------------------|---------------|
+| TensorFlow*    | yes               | yes           |
+| MXNet*         | yes               | yes           |
+| PaddlePaddle   | yes               | yes           |
+| PyTorch*       | no                | yes           |
+| Chainer*       | no                | yes           |
+| CNTK*          | no                | yes           |
+| Caffe2*        | no                | yes           |
+
+| Backend                                         | current support | future support |
+|-------------------------------------------------|-----------------|----------------|
+| Intel® Architecture CPU                         | yes             | yes            |
+| Intel® Nervana™ Neural Network Processor (NNP)  | yes             | yes            |
+| Intel [Movidius™ Myriad™ 2] VPUs                | coming soon     | yes            |
+| Intel® Architecture GPUs                        | via PlaidML     | yes            |
+| AMD* GPUs                                       | via PlaidML     | yes            |
+| NVIDIA* GPUs                                    | via PlaidML     | some           |
+| Field Programmable Gate Arrays (FPGA)           | no              | yes            |

## Documentation

@@ -72,12 +84,13 @@ to improve the Library:

[install]: http://ngraph.nervanasys.com/docs/latest/buildlb.html
[framework integration guides]: http://ngraph.nervanasys.com/docs/latest/framework-integration-guides.html
+[release notes]: http://ngraph.nervanasys.com/docs/latest/project/release-notes.html
[Github issues]: https://github.com/NervanaSystems/ngraph/issues
[contrib guide]: http://ngraph.nervanasys.com/docs/latest/project/code-contributor-README.html
[pull request]: https://github.com/NervanaSystems/ngraph/pulls
[how to import]: http://ngraph.nervanasys.com/docs/latest/howto/import.html
-[ngraph-ecosystem]: doc/sphinx/source/graphics/ngraph-ecosystem.png "nGraph Ecosystem"
+[ngraph-ecosystem]: doc/sphinx/source/graphics/599px-Intel-ngraph-ecosystem.png "nGraph Ecosystem"
[build-status]: https://travis-ci.org/NervanaSystems/ngraph/branches
[build-status-badge]: https://travis-ci.org/NervanaSystems/ngraph.svg?branch=master
[develop-without-lockin]: doc/sphinx/source/graphics/develop-without-lockin.png "Develop on any part of the stack without lock-in"
-[Movidius]: https://www.movidius.com/solutions/vision-processing-unit
+[Movidius™ Myriad™ 2]: https://www.movidius.com/solutions/vision-processing-unit
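
A quick illustration of the import path the README describes. This is a hedged sketch only: the entry-point name `import_onnx_function` and the header locations below are assumptions, not verbatim API; the real calls are documented in the [how to import] guide. The importer dispatches ONNX nodes to converters such as `op::set_1::add` shown in the diffs further down.

```cpp
// Hedged sketch of running an ONNX model through nGraph (names assumed).
#include <fstream>
#include <memory>

#include "ngraph/frontend/onnx_import/onnx.hpp" // assumed header location
#include "ngraph/runtime/backend.hpp"

int main()
{
    std::ifstream model{"model.onnx", std::ios::binary};

    // Assumed entry point: converts the ONNX graph into an ngraph::Function.
    std::shared_ptr<ngraph::Function> function =
        ngraph::onnx_import::import_onnx_function(model);

    // Compile the function on a named backend, e.g. the CPU backend.
    auto backend = ngraph::runtime::Backend::create("CPU");
    backend->compile(function);
    return 0;
}
```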
# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
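# Registers a unit-test BACKEND together with an optional test MANIFEST file.
# Each call appends a "BACKEND@MANIFEST" (or "BACKEND@" when no manifest is
# given) entry to the cached UNIT_TEST_CONFIG_LIST variable.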
function(UNIT_TEST_CONTROL)
    set(options)
    set(oneValueArgs BACKEND MANIFEST)
    set(multiValueArgs)
    cmake_parse_arguments(UNIT_TEST_CONTROL "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    if (UNIT_TEST_CONTROL_MANIFEST)
        get_filename_component(UNIT_TEST_CONTROL_MANIFEST ${UNIT_TEST_CONTROL_MANIFEST} ABSOLUTE)
        set(CONFIG_STRING "${UNIT_TEST_CONTROL_BACKEND}@${UNIT_TEST_CONTROL_MANIFEST}")
    else()
        set(CONFIG_STRING "${UNIT_TEST_CONTROL_BACKEND}@")
    endif()
    set(UNIT_TEST_CONFIG_LIST "${UNIT_TEST_CONFIG_LIST};${CONFIG_STRING}" CACHE INTERNAL "")
endfunction()
@@ -68,9 +68,9 @@ author = 'Intel Corporation'
# built documents.
#
# The short X.Y version.
-version = '0.8'
+version = '0.9'
# The full version, including alpha/beta/rc tags.
-release = '0.8.0'
+release = '0.9.0'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
...
@@ -30,8 +30,8 @@ software engineers, and others with the means to make their work :ref:`portable`
:abbr:`Machine Learning (ML)` hardware available today: optimized Deep Learning
computation devices.

-.. figure:: graphics/ngraph-ecosystem.png
-   :width: 650px
+.. figure:: graphics/599px-Intel-ngraph-ecosystem.png
+   :width: 599px

.. _portable:
@@ -66,16 +66,15 @@ Python-based API. See the `ngraph onnx companion tool`_ to get started.

.. csv-table::
-   :header: "Framework", "Bridge Code Available?", "ONNX Support?"
+   :header: "Framework", "Bridge Available?", "ONNX Support?"
   :widths: 27, 10, 10

   TensorFlow, Yes, Yes
   MXNet, Yes, Yes
   PaddlePaddle, Coming Soon, Yes
-   neon, none needed, Yes
-   PyTorch, Coming Soon, Yes
-   CNTK, Not yet, Yes
-   Other, Not yet, Doable
+   PyTorch, No, Yes
+   CNTK, No, Yes
+   Other, Custom, Custom

.. _deployable:
@@ -104,29 +103,30 @@ model to run on a variety of backends:

.. csv-table::
-   :header: "Backend", "Current nGraph support", "Future nGraph support"
+   :header: "Backend", "Current support", "Future nGraph support"
   :widths: 35, 10, 10

   Intel® Architecture Processors (CPUs), Yes, Yes
-   Intel® Nervana™ Neural Network Processor™ (NNPs), Yes, Yes
-   NVIDIA\* CUDA (GPUs), Yes, Some
+   Intel® Nervana™ Neural Network Processor (NNPs), Yes, Yes
+   AMD\* GPUs, via PlaidML, Yes
+   NVIDIA\* GPUs, via PlaidML, Some
+   Intel® Architecture GPUs, Yes, Yes
   :abbr:`Field Programmable Gate Arrays (FPGA)` (FPGAs), Coming soon, Yes
-   `Movidius`_, Not yet, Yes
+   Intel Movidius™ Myriad™ 2 (VPU), Coming soon, Yes
   Other, Not yet, Ask

The value we're offering to the developer community is empowerment: we are
confident that Intel® Architecture already provides the best computational
resources available for the breadth of ML/DL tasks. We welcome ideas and
`contributions`_ from the community.

Further project details can be found on our :doc:`project/about` page, or see
our :doc:`buildlb` guide for how to get started.

-.. note:: The library code is under active development as we're continually
+.. note:: The Library code is under active development as we're continually
   adding support for more kinds of DL models and ops, framework compiler
   optimizations, and backends.

=======
@@ -152,7 +152,6 @@ Contents

   project/index.rst

Indices and tables
==================
@@ -160,8 +159,7 @@ Indices and tables

* :ref:`genindex`

.. _ONNX: http://onnx.ai
.. _ngraph onnx companion tool: https://github.com/NervanaSystems/ngraph-onnx
.. _Movidius: https://www.movidius.com/
.. _contributions: https://github.com/NervanaSystems/ngraph#how-to-contribute
\ No newline at end of file
@@ -5,7 +5,7 @@ Overview

Welcome to the documentation site for |InG|, an open-source C++ Compiler,
-Library, and runtime suite for running training and inference on
+Library, and runtime suite for Deep Learning frameworks running training and inference on
:abbr:`Deep Neural Network (DNN)` models. nGraph is framework-neutral and can be
targeted for programming and deploying :abbr:`Deep Learning (DL)` applications
on the most modern compute and edge devices.

@@ -22,8 +22,8 @@ Features

Develop without lock-in
~~~~~~~~~~~~~~~~~~~~~~~

-.. figure:: ../graphics/develop-without-lockin.png
-   :width: 650px
+.. figure:: ../graphics/599px-Intel-ngraph-ecosystem.png
+   :width: 599px

Being able to increase training performance or reduce inference latency by simply
@@ -34,47 +34,6 @@ developers working with nGraph. Our commitment to bake flexibility into our
ecosystem ensures developers' freedom to design user-facing APIs for various
hardware deployments directly into their frameworks.

-.. figure:: ../graphics/ngraph-ecosystem.png
-   :width: 585px
-
-nGraph currently supports :doc:`three popular <../framework-integration-guides>`
-frameworks for :abbr:`Deep Learning (DL)` models through what we call
-a :term:`bridge` that can be integrated during the framework's build time.
-For developers working with other frameworks (even those not listed above),
-we've created a :doc:`How to Guide <../howto/index>` so you can learn how to
-create custom bridge code that can be used to
-:doc:`compile and run <../howto/execute>` a training model.
-
-Additionally, nGraph Library supports the `ONNX`_ format. Developers who
-already have a "trained" model can use nGraph to bypass much of the
-framework-based complexity and :doc:`../howto/import` to test or run it
-on targeted and efficient backends with our user-friendly ``ngraph_api``.
-
-With nGraph, data scientists can focus on data science rather than worrying
-about how to adapt models to train and run efficiently on different devices.
-Be sure to add the ``-DNGRAPH_ONNX_IMPORT_ENABLE=ON`` option when running ``cmake``
-to build the Library.
-
-Supported platforms
--------------------
-
-* Intel® Architecture Processors (CPUs),
-* Intel® Nervana™ Neural Network Processor™ (NNPs), and
-* NVIDIA\* CUDA (GPUs).
-
-We built the first generation of the Intel Nervana™ NNP family of processors
-last year to show that the nGraph Library can be used to train a
-:abbr:`Neural Network (NN)` more quickly. The more advanced the silicon, the
-more powerful a lightweight library can be. So while we do currently support
-traditional GPUs, they are not advanced silicon, and trying to scale workloads
-using traditional GPU libraries is clunky and brittle with bottlenecks. Iteration
-from an already-trained NN model to one that can also perform inference
-computations is immensely simplified. Read more about these compute-friendly
-options in the documentation for :doc:`../fusion/index`.
-
.. note:: The library code is under active development as we're continually
   adding support for more kinds of DL models and ops, framework compiler
   optimizations, and backends.
...
@@ -15,13 +15,13 @@
.. limitations under the License.
.. ---------------------------------------------------------------------------

nGraph Library docs
-====================
+===================

Read this for changes affecting anything in ``ngraph/doc``
----------------------------------------------------------

For updates to the Intel® nGraph Library ``/doc`` repo, please submit a PR with
any changes or ideas you'd like integrated. This helps us maintain trackability
with respect to additions or feature requests.
@@ -127,7 +127,7 @@ To build documentation locally, run:

.. code-block:: console

-   $ pip3 install [-I] Sphinx==1.6.5 [--user]
+   $ pip3 install [-I] Sphinx==1.7.5 [--user]
   $ pip3 install [-I] breathe numpy [--user]
   $ cd doc/sphinx/
   $ make html
@@ -150,11 +150,16 @@ To build documentation in a python3 virtualenv, run:

Then point your browser at ``localhost:8000``.

+.. note:: For docs built in a virtualenv, the latest Sphinx changes may break
+   the documentation build; try building with a pinned Sphinx version.
+
For tips on writing reStructuredText-formatted documentation, see the `sphinx`_
stable reST documentation.

-.. _ngraph repo: https://github.com/NervanaSystems/ngraph-cpp/
-.. _documentation repo: https://github.com/NervanaSystems/private-ngraph/tree/master/doc
+.. _ngraph repo: https://github.com/NervanaSystems/ngraph/
+.. _documentation repo: https://github.com/NervanaSystems/ngraph/tree/master/doc
.. _sphinx: http://www.sphinx-doc.org/en/stable/rest.html
.. _wiki: https://github.com/NervanaSystems/ngraph/wiki/
.. _breathe: https://breathe.readthedocs.io/en/latest/
...
@@ -3,5 +3,12 @@

Release Notes
#############

-This is the |version| of release.
+This is release |release|.
+
+API Changes
+===========
+
+.. literalinclude:: ../../../../changes.md
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector abs(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector abs(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Abs>(node.get_ng_inputs().at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector add(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector add(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {std::make_shared<ngraph::op::Add>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
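Here and in the other binary converters below, `numpy_style_broadcast_for_binary_operation` appears to do what its name suggests: the two inputs are broadcast to a common shape before the nGraph op is built, so, for example, operands of shapes {3, 4} and {4} both become {3, 4} before `ngraph::op::Add` is constructed.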
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector logical_and(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector logical_and(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -26,12 +26,16 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector average_pool(const Node& node)
-            {
-                return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
-            }
+            namespace set_1
+            {
+                NodeVector average_pool(const Node& node)
+                {
+                    return convpool::make_ng_pool<ngraph::op::AvgPool>(node);
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -26,17 +26,21 @@ namespace ngraph
    {
        namespace op
        {
-            /**
-             * @brief Convert ONNX AveragePool operation to an nGraph node.
-             *
-             * @param node The ONNX node object representing this operation.
-             *
-             * @return The vector containing Ngraph nodes producing output of ONNX AveragePool
-             *         operation.
-             */
-            NodeVector average_pool(const Node& node);
+            namespace set_1
+            {
+                /**
+                 * @brief Convert ONNX AveragePool operation to an nGraph node.
+                 *
+                 * @param node The ONNX node object representing this operation.
+                 *
+                 * @return The vector containing Ngraph nodes producing output of ONNX AveragePool
+                 *         operation.
+                 */
+                NodeVector average_pool(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,38 +28,42 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector batch_norm(const Node& node)
-            {
-                NodeVector inputs{node.get_ng_inputs()};
-                auto x = inputs.at(0);
-                auto scale = inputs.at(1);
-                auto bias = inputs.at(2);
-                std::shared_ptr<ngraph::Node> mean{nullptr};
-                std::shared_ptr<ngraph::Node> var{nullptr};
-
-                int is_test{node.get_attribute_value<int>("is_test", 1)};
-                int spatial{node.get_attribute_value<int>("spatial", 1)};
-                double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
-                // TODO: Implement learning mode support
-                // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
-                bool training = false;
-
-                ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
-                ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
-
-                if (inputs.size() >= 5)
-                {
-                    mean = inputs.at(3);
-                    var = inputs.at(4);
-                    return {std::make_shared<ngraph::op::BatchNorm>(
-                        epsilon, scale, bias, x, mean, var, training)};
-                }
-
-                return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
-            }
+            namespace set_1
+            {
+                NodeVector batch_norm(const Node& node)
+                {
+                    NodeVector inputs{node.get_ng_inputs()};
+                    auto x = inputs.at(0);
+                    auto scale = inputs.at(1);
+                    auto bias = inputs.at(2);
+                    std::shared_ptr<ngraph::Node> mean{nullptr};
+                    std::shared_ptr<ngraph::Node> var{nullptr};
+
+                    int is_test{node.get_attribute_value<int>("is_test", 1)};
+                    int spatial{node.get_attribute_value<int>("spatial", 1)};
+                    double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
+                    // TODO: Implement learning mode support
+                    // float momentum{node.get_attribute_value<float>("momentum", 0.9f)};
+                    bool training = false;
+
+                    ASSERT_IS_SUPPORTED(node, is_test) << "only 'is_test' mode is supported.";
+                    ASSERT_IS_SUPPORTED(node, spatial) << "only 'spatial' mode is supported.";
+
+                    if (inputs.size() >= 5)
+                    {
+                        mean = inputs.at(3);
+                        var = inputs.at(4);
+                        return {std::make_shared<ngraph::op::BatchNorm>(
+                            epsilon, scale, bias, x, mean, var, training)};
+                    }
+
+                    return {std::make_shared<ngraph::op::BatchNorm>(epsilon, scale, bias, x)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
} // namespace ngraph
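As implemented, only inference ("is_test") is supported: with five inputs the supplied mean and var feed the standard inference-time normalization, y = scale * (x - mean) / sqrt(var + epsilon) + bias, per channel; with three inputs the four-argument BatchNorm overload is used, leaving the statistics to the op itself.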
@@ -24,9 +24,14 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector batch_norm(const Node& node);
+            namespace set_1
+            {
+                NodeVector batch_norm(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
} // namespace ngraph
@@ -30,34 +30,40 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector cast(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-                int64_t target_type = node.get_attribute_value<int64_t>("to");
-                element::Type elem_type;
-
-                switch (target_type)
-                {
-                case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
-                case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
-                case onnx::TensorProto_DataType_FLOAT16:
-                case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
-                case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
-                case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
-                case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
-                case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
-                case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
-                case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
-                case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
-                case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
-                case onnx::TensorProto_DataType_UNDEFINED: elem_type = element::unspecified; break;
-                default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
-                }
-
-                return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
-            }
+            namespace set_1
+            {
+                NodeVector cast(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+                    int64_t target_type = node.get_attribute_value<int64_t>("to");
+                    element::Type elem_type;
+
+                    switch (target_type)
+                    {
+                    case onnx::TensorProto_DataType_BOOL: elem_type = element::boolean; break;
+                    case onnx::TensorProto_DataType_DOUBLE: elem_type = element::f64; break;
+                    case onnx::TensorProto_DataType_FLOAT16:
+                    case onnx::TensorProto_DataType_FLOAT: elem_type = element::f32; break;
+                    case onnx::TensorProto_DataType_INT8: elem_type = element::i8; break;
+                    case onnx::TensorProto_DataType_INT16: elem_type = element::i16; break;
+                    case onnx::TensorProto_DataType_INT32: elem_type = element::i32; break;
+                    case onnx::TensorProto_DataType_INT64: elem_type = element::i64; break;
+                    case onnx::TensorProto_DataType_UINT8: elem_type = element::u8; break;
+                    case onnx::TensorProto_DataType_UINT16: elem_type = element::u16; break;
+                    case onnx::TensorProto_DataType_UINT32: elem_type = element::u32; break;
+                    case onnx::TensorProto_DataType_UINT64: elem_type = element::u64; break;
+                    case onnx::TensorProto_DataType_UNDEFINED:
+                        elem_type = element::unspecified;
+                        break;
+                    default: ASSERT_IS_SUPPORTED(node, false) << "unsupported type";
+                    }
+
+                    return {std::make_shared<ngraph::op::Convert>(data, elem_type)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector cast(const Node& node);
+            namespace set_1
+            {
+                NodeVector cast(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector ceil(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector ceil(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Ceiling>(node.get_ng_inputs().at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -34,30 +34,37 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector clip(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-
-                double max_value =
-                    node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
-                double min_value =
-                    node.get_attribute_value<double>("min", std::numeric_limits<double>::lowest());
-
-                std::shared_ptr<ngraph::Node> max_value_node =
-                    std::make_shared<ngraph::op::Constant>(
-                        data->get_element_type(), ngraph::Shape{}, std::vector<double>{max_value});
-                max_value_node = make_broadcast_node(max_value_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> min_value_node =
-                    std::make_shared<ngraph::op::Constant>(
-                        data->get_element_type(), ngraph::Shape{}, std::vector<double>{min_value});
-                min_value_node = make_broadcast_node(min_value_node, data->get_shape());
-
-                return {std::make_shared<ngraph::op::Minimum>(
-                    max_value_node, std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
-            }
+            namespace set_1
+            {
+                NodeVector clip(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+
+                    double max_value =
+                        node.get_attribute_value<double>("max", std::numeric_limits<double>::max());
+                    double min_value = node.get_attribute_value<double>(
+                        "min", std::numeric_limits<double>::lowest());
+
+                    std::shared_ptr<ngraph::Node> max_value_node =
+                        std::make_shared<ngraph::op::Constant>(data->get_element_type(),
+                                                               ngraph::Shape{},
+                                                               std::vector<double>{max_value});
+                    max_value_node = make_broadcast_node(max_value_node, data->get_shape());
+
+                    std::shared_ptr<ngraph::Node> min_value_node =
+                        std::make_shared<ngraph::op::Constant>(data->get_element_type(),
+                                                               ngraph::Shape{},
+                                                               std::vector<double>{min_value});
+                    min_value_node = make_broadcast_node(min_value_node, data->get_shape());
+
+                    return {std::make_shared<ngraph::op::Minimum>(
+                        max_value_node,
+                        std::make_shared<ngraph::op::Maximum>(data, min_value_node))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
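The Minimum/Maximum composition above is the usual clip identity, clip(x, lo, hi) = min(hi, max(x, lo)), with the scalar bounds broadcast to the shape of `data` first.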
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector clip(const Node& node);
+            namespace set_1
+            {
+                NodeVector clip(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -24,16 +24,20 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector concat(const Node& node)
-            {
-                NodeVector inputs{node.get_ng_inputs()};
-                auto axis = node.get_attribute_value<int64_t>("axis");
-
-                return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
-            }
+            namespace set_1
+            {
+                NodeVector concat(const Node& node)
+                {
+                    NodeVector inputs{node.get_ng_inputs()};
+                    auto axis = node.get_attribute_value<int64_t>("axis");
+
+                    return {std::make_shared<ngraph::op::Concat>(inputs, axis)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
} // namespace ngraph
@@ -26,9 +26,14 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector concat(const Node& node);
+            namespace set_1
+            {
+                NodeVector concat(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
} // namespace ngraph
@@ -25,96 +25,101 @@ namespace ngraph
    {
        namespace op
        {
-            namespace
-            {
-                template <typename T>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    __make_ng_constant(const element::Type& type, const Tensor& tensor)
-                {
-                    return std::make_shared<ngraph::op::Constant>(
-                        type, tensor.get_shape(), tensor.get_data<T>());
-                }
-
-                template <Tensor::Type>
-                inline std::shared_ptr<ngraph::op::Constant> make_ng_constant(const Tensor& tensor)
-                {
-                    throw error::tensor::unsupported_data_type{tensor};
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<float>(element::f32, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<float>(element::f32, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<double>(element::f64, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<int32_t>(element::i32, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<int64_t>(element::i64, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<uint32_t>(element::u32, tensor);
-                }
-
-                template <>
-                inline std::shared_ptr<ngraph::op::Constant>
-                    make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
-                {
-                    return __make_ng_constant<uint64_t>(element::u64, tensor);
-                }
-
-                inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
-                {
-#define MAKE_NG_CONSTANT(data_type_) \
-    case data_type_: return make_ng_constant<data_type_>(tensor)
-
-                    switch (tensor.get_type())
-                    {
-                        MAKE_NG_CONSTANT(Tensor::Type::float16);
-                        MAKE_NG_CONSTANT(Tensor::Type::float32);
-                        MAKE_NG_CONSTANT(Tensor::Type::float64);
-                        MAKE_NG_CONSTANT(Tensor::Type::int32);
-                        MAKE_NG_CONSTANT(Tensor::Type::int64);
-                        MAKE_NG_CONSTANT(Tensor::Type::uint32);
-                        MAKE_NG_CONSTANT(Tensor::Type::uint64);
-                        default: throw error::tensor::invalid_data_type{tensor};
-                    }
-                }
-            }
-
-            NodeVector constant(const onnx_import::Node& node)
-            {
-                return {make_constant(node.get_attribute_value<Tensor>("value"))};
-            }
+            namespace set_1
+            {
+                namespace
+                {
+                    template <typename T>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        __make_ng_constant(const element::Type& type, const Tensor& tensor)
+                    {
+                        return std::make_shared<ngraph::op::Constant>(
+                            type, tensor.get_shape(), tensor.get_data<T>());
+                    }
+
+                    template <Tensor::Type>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant(const Tensor& tensor)
+                    {
+                        throw error::tensor::unsupported_data_type{tensor};
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::float16>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<float>(element::f32, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::float32>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<float>(element::f32, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::float64>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<double>(element::f64, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::int32>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<int32_t>(element::i32, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::int64>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<int64_t>(element::i64, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::uint32>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<uint32_t>(element::u32, tensor);
+                    }
+
+                    template <>
+                    inline std::shared_ptr<ngraph::op::Constant>
+                        make_ng_constant<Tensor::Type::uint64>(const Tensor& tensor)
+                    {
+                        return __make_ng_constant<uint64_t>(element::u64, tensor);
+                    }
+
+                    inline std::shared_ptr<ngraph::op::Constant> make_constant(const Tensor& tensor)
+                    {
+#define MAKE_NG_CONSTANT(data_type_) \
+    case data_type_: return make_ng_constant<data_type_>(tensor)
+
+                        switch (tensor.get_type())
+                        {
+                            MAKE_NG_CONSTANT(Tensor::Type::float16);
+                            MAKE_NG_CONSTANT(Tensor::Type::float32);
+                            MAKE_NG_CONSTANT(Tensor::Type::float64);
+                            MAKE_NG_CONSTANT(Tensor::Type::int32);
+                            MAKE_NG_CONSTANT(Tensor::Type::int64);
+                            MAKE_NG_CONSTANT(Tensor::Type::uint32);
+                            MAKE_NG_CONSTANT(Tensor::Type::uint64);
+                            default: throw error::tensor::invalid_data_type{tensor};
+                        }
+                    }
+                }
+
+                NodeVector constant(const onnx_import::Node& node)
+                {
+                    return {make_constant(node.get_attribute_value<Tensor>("value"))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
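Worth noting for reviewers: `make_constant` dispatches on the runtime `Tensor::Type` through the `MAKE_NG_CONSTANT` macro to the matching `make_ng_constant<>` specialization, while the unspecialized template throws `unsupported_data_type`. Also note that `float16` tensors are widened to `element::f32` constants.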
@@ -26,9 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector constant(const Node& node);
+            namespace set_1
+            {
+                NodeVector constant(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
...
@@ -26,15 +26,19 @@ namespace ngraph
    {
        namespace op
        {
-            /// \brief Performs ONNX Conv operation.
-            ///
-            /// \param node The ONNX node object representing this operation.
-            ///
-            /// \return The vector containing Ngraph nodes producing output of ONNX convolution
-            ///         operation.
-            NodeVector conv(const Node& node);
+            namespace set_1
+            {
+                /// \brief Performs ONNX Conv operation.
+                ///
+                /// \param node The ONNX node object representing this operation.
+                ///
+                /// \return The vector containing Ngraph nodes producing output of ONNX convolution
+                ///         operation.
+                NodeVector conv(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector div(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector div(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {std::make_shared<ngraph::op::Divide>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -38,26 +38,33 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector elu(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-                double alpha = node.get_attribute_value<double>("alpha", 1);
-
-                std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), Shape{}, std::vector<double>{alpha});
-                alpha_node = make_broadcast_node(alpha_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), Shape{}, std::vector<double>{0});
-                zero_node = make_broadcast_node(zero_node, data->get_shape());
-
-                return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
-                        alpha_node * std::make_shared<ngraph::op::Exp>(
-                            std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
-                        alpha_node};
-            }
+            namespace set_1
+            {
+                NodeVector elu(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+                    double alpha = node.get_attribute_value<double>("alpha", 1);
+
+                    std::shared_ptr<ngraph::Node> alpha_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), Shape{}, std::vector<double>{alpha});
+                    alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+
+                    std::shared_ptr<ngraph::Node> zero_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), Shape{}, std::vector<double>{0});
+                    zero_node = make_broadcast_node(zero_node, data->get_shape());
+
+                    return {std::make_shared<ngraph::op::Maximum>(data, zero_node) +
+                            alpha_node *
+                                std::make_shared<ngraph::op::Exp>(
+                                    std::make_shared<ngraph::op::Minimum>(data, zero_node)) -
+                            alpha_node};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
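The returned expression, max(x, 0) + alpha * exp(min(x, 0)) - alpha, matches the usual ELU definition: for x > 0 it reduces to x (exp(0) = 1 cancels the trailing alpha), and for x <= 0 it reduces to alpha * (exp(x) - 1).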
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector elu(const Node& node);
+            namespace set_1
+            {
+                NodeVector elu(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector equal(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector equal(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {std::make_shared<ngraph::op::Equal>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector exp(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector exp(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Exp>(node.get_ng_inputs().at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -25,19 +25,23 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector flatten(const Node& node)
-            {
-                NodeVector inputs{node.get_ng_inputs()};
-                auto data = inputs.at(0);
-                auto axis = node.get_attribute_value<int64_t>("axis", 1);
-
-                ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
-                    << "provided 'axis' attribute is not valid.";
-
-                return {reshape::flatten(data, axis)};
-            }
+            namespace set_1
+            {
+                NodeVector flatten(const Node& node)
+                {
+                    NodeVector inputs{node.get_ng_inputs()};
+                    auto data = inputs.at(0);
+                    auto axis = node.get_attribute_value<int64_t>("axis", 1);
+
+                    ASSERT_VALID_ARGUMENT(node, (axis >= 0) && (axis <= data->get_shape().size()))
+                        << "provided 'axis' attribute is not valid.";
+
+                    return {reshape::flatten(data, axis)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,8 +28,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector flatten(const Node& node);
+            namespace set_1
+            {
+                NodeVector flatten(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector floor(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector floor(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Floor>(node.get_ng_inputs().at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -32,50 +32,58 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector gemm(const Node& node)
-            {
-                NodeVector inputs{node.get_ng_inputs()};
-                auto input_a = inputs.at(0);
-                auto input_b = inputs.at(1);
-                auto input_c = inputs.at(2);
-
-                double alpha = node.get_attribute_value<double>("alpha", 1);
-                double beta = node.get_attribute_value<double>("beta", 1);
-
-                auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
-                auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
-
-                if (trans_a != 0)
-                {
-                    input_a = reshape::transpose(input_a);
-                }
-
-                if (trans_b != 0)
-                {
-                    input_b = reshape::transpose(input_b);
-                }
-
-                // code from python not implemented in c++ yet.
-                // reshape_for_matmul(node, input_a, input_b);
-
-                std::shared_ptr<ngraph::Node> a_dot_b =
-                    std::make_shared<ngraph::op::Dot>(input_a, input_b);
-
-                std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
-                    a_dot_b->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
-                alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
-                a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
-
-                std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
-                    input_c->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
-                beta_node = make_broadcast_node(beta_node, input_c->get_shape());
-                input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
-                input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
-
-                return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
-            }
+            namespace set_1
+            {
+                NodeVector gemm(const Node& node)
+                {
+                    NodeVector inputs{node.get_ng_inputs()};
+                    auto input_a = inputs.at(0);
+                    auto input_b = inputs.at(1);
+                    auto input_c = inputs.at(2);
+
+                    double alpha = node.get_attribute_value<double>("alpha", 1);
+                    double beta = node.get_attribute_value<double>("beta", 1);
+
+                    auto trans_a = node.get_attribute_value<int64_t>("transA", 0);
+                    auto trans_b = node.get_attribute_value<int64_t>("transB", 0);
+
+                    if (trans_a != 0)
+                    {
+                        input_a = reshape::transpose(input_a);
+                    }
+
+                    if (trans_b != 0)
+                    {
+                        input_b = reshape::transpose(input_b);
+                    }
+
+                    // code from python not implemented in c++ yet.
+                    // reshape_for_matmul(node, input_a, input_b);
+
+                    std::shared_ptr<ngraph::Node> a_dot_b =
+                        std::make_shared<ngraph::op::Dot>(input_a, input_b);
+
+                    std::shared_ptr<ngraph::Node> alpha_node =
+                        std::make_shared<ngraph::op::Constant>(a_dot_b->get_element_type(),
+                                                               ngraph::Shape{},
+                                                               std::vector<double>{alpha});
+                    alpha_node = make_broadcast_node(alpha_node, a_dot_b->get_shape());
+                    a_dot_b = std::make_shared<ngraph::op::Multiply>(alpha_node, a_dot_b);
+
+                    std::shared_ptr<ngraph::Node> beta_node =
+                        std::make_shared<ngraph::op::Constant>(input_c->get_element_type(),
+                                                               ngraph::Shape{},
+                                                               std::vector<double>{beta});
+                    beta_node = make_broadcast_node(beta_node, input_c->get_shape());
+                    input_c = std::make_shared<ngraph::op::Multiply>(beta_node, input_c);
+                    input_c = make_broadcast_node(input_c, a_dot_b->get_shape());
+
+                    return {std::make_shared<ngraph::op::Add>(a_dot_b, input_c)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
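Putting the pieces together, the converter realizes the ONNX Gemm contract Y = alpha * A' . B' + beta * C, where A' and B' are optionally transposed according to transA/transB and C is broadcast to the shape of A' . B' before the final Add.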
@@ -28,8 +28,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector gemm(const Node& node);
+            namespace set_1
+            {
+                NodeVector gemm(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,19 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector greater(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector greater(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {
+                        std::make_shared<ngraph::op::Greater>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -35,36 +35,43 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector hard_sigmoid(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-
-                double alpha = node.get_attribute_value<double>("alpha", 0.2);
-                double beta = node.get_attribute_value<double>("beta", 0.5);
-
-                std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
-                alpha_node = make_broadcast_node(alpha_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> beta_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
-                beta_node = make_broadcast_node(beta_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), Shape{}, std::vector<double>{1});
-                one_node = make_broadcast_node(one_node, data->get_shape());
-
-                std::shared_ptr<ngraph::Node> zero_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), Shape{}, std::vector<double>{0});
-                zero_node = make_broadcast_node(zero_node, data->get_shape());
-
-                return {std::make_shared<ngraph::op::Maximum>(
-                    zero_node,
-                    std::make_shared<ngraph::op::Minimum>(one_node,
-                                                          alpha_node * data + beta_node))};
-            }
+            namespace set_1
+            {
+                NodeVector hard_sigmoid(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+
+                    double alpha = node.get_attribute_value<double>("alpha", 0.2);
+                    double beta = node.get_attribute_value<double>("beta", 0.5);
+
+                    std::shared_ptr<ngraph::Node> alpha_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
+                    alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+
+                    std::shared_ptr<ngraph::Node> beta_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{beta});
+                    beta_node = make_broadcast_node(beta_node, data->get_shape());
+
+                    std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
+                        data->get_element_type(), Shape{}, std::vector<double>{1});
+                    one_node = make_broadcast_node(one_node, data->get_shape());
+
+                    std::shared_ptr<ngraph::Node> zero_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), Shape{}, std::vector<double>{0});
+                    zero_node = make_broadcast_node(zero_node, data->get_shape());
+
+                    return {std::make_shared<ngraph::op::Maximum>(
+                        zero_node,
+                        std::make_shared<ngraph::op::Minimum>(one_node,
+                                                              alpha_node * data + beta_node))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
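That is the standard HardSigmoid, y = max(0, min(1, alpha * x + beta)), with the defaults alpha = 0.2 and beta = 0.5 read from the node attributes.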
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector hard_sigmoid(const Node& node);
+            namespace set_1
+            {
+                NodeVector hard_sigmoid(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,15 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector identity(const Node& node) { return {node.get_ng_inputs().at(0)}; }
+            namespace set_1
+            {
+                inline NodeVector identity(const Node& node)
+                {
+                    return {node.get_ng_inputs().at(0)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -38,21 +38,26 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector leaky_relu(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-                double alpha = node.get_attribute_value<double>("alpha", 0.01);
-
-                ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
-                    << " alpha value should be in range (0,1)";
-
-                std::shared_ptr<ngraph::Node> alpha_node = std::make_shared<ngraph::op::Constant>(
-                    data->get_element_type(), Shape{}, std::vector<double>{alpha});
-                alpha_node = make_broadcast_node(alpha_node, data->get_shape());
-                return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
-            }
+            namespace set_1
+            {
+                NodeVector leaky_relu(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+                    double alpha = node.get_attribute_value<double>("alpha", 0.01);
+
+                    ASSERT_VALID_ARGUMENT(node, ((alpha >= 0) && (alpha <= 1)))
+                        << " alpha value should be in range (0,1)";
+
+                    std::shared_ptr<ngraph::Node> alpha_node =
+                        std::make_shared<ngraph::op::Constant>(
+                            data->get_element_type(), Shape{}, std::vector<double>{alpha});
+                    alpha_node = make_broadcast_node(alpha_node, data->get_shape());
+                    return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
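Because the asserted range gives 0 <= alpha <= 1, max(alpha * x, x) equals x for x >= 0 and alpha * x otherwise, which is exactly LeakyReLU.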
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector leaky_relu(const Node& node);
+            namespace set_1
+            {
+                NodeVector leaky_relu(const Node& node);
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector less(const Node& node)
-            {
-                NodeVector ng_inputs{
-                    numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
-                return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
-            }
+            namespace set_1
+            {
+                inline NodeVector less(const Node& node)
+                {
+                    NodeVector ng_inputs{
+                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
+                    return {std::make_shared<ngraph::op::Less>(ng_inputs.at(0), ng_inputs.at(1))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector log(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector log(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Log>(node.get_ng_inputs().at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
@@ -31,13 +31,17 @@ namespace ngraph
    {
        namespace op
        {
-            inline NodeVector log_softmax(const Node& node)
-            {
-                return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
-            }
+            namespace set_1
+            {
+                inline NodeVector log_softmax(const Node& node)
+                {
+                    return {std::make_shared<ngraph::op::Log>(softmax(node).at(0))};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
} // namespace ngraph
@@ -27,18 +27,22 @@ namespace ngraph
    {
        namespace op
        {
-            NodeVector lrn(const Node& node)
-            {
-                auto data = node.get_ng_inputs().at(0);
-                double alpha = node.get_attribute_value<double>("alpha", 1e-4);
-                double beta = node.get_attribute_value<double>("beta", 0.75);
-                double bias = node.get_attribute_value<double>("bias", 1);
-                size_t size = node.get_attribute_value<size_t>("size");
-
-                return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
-            }
+            namespace set_1
+            {
+                NodeVector lrn(const Node& node)
+                {
+                    auto data = node.get_ng_inputs().at(0);
+                    double alpha = node.get_attribute_value<double>("alpha", 1e-4);
+                    double beta = node.get_attribute_value<double>("beta", 0.75);
+                    double bias = node.get_attribute_value<double>("bias", 1);
+                    size_t size = node.get_attribute_value<size_t>("size");
+
+                    return {std::make_shared<ngraph::op::LRN>(data, alpha, beta, bias, size)};
+                }
+
+            } // namespace set_1

-        } // namespace op
+        } //namespace op
    } // namespace onnx_import
...
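If memory serves, the ONNX spec defines LRN over a local window of `size` channels as b = a / (bias + (alpha / size) * sum(a^2))^beta; the attribute defaults above (alpha = 1e-4, beta = 0.75, bias = 1) match that definition, with the computation delegated to `ngraph::op::LRN`.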
...@@ -26,8 +26,12 @@ namespace ngraph ...@@ -26,8 +26,12 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector lrn(const Node& node); namespace set_1
} // namespace op {
NodeVector lrn(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,13 +27,17 @@ namespace ngraph ...@@ -27,13 +27,17 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector matmul(const Node& node) namespace set_1
{ {
NodeVector ng_inputs{node.get_ng_inputs()}; inline NodeVector matmul(const Node& node)
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))}; {
} NodeVector ng_inputs{node.get_ng_inputs()};
return {std::make_shared<ngraph::op::Dot>(ng_inputs.at(0), ng_inputs.at(1))};
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -28,12 +28,16 @@ namespace ngraph ...@@ -28,12 +28,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
inline NodeVector max(const Node& node) namespace set_1
{ {
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node); inline NodeVector max(const Node& node)
} {
return variadic::make_ng_variadic_op<ngraph::op::Maximum>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,12 +26,16 @@ namespace ngraph ...@@ -26,12 +26,16 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector max_pool(const Node& node) namespace set_1
{ {
return convpool::make_ng_pool<ngraph::op::MaxPool>(node); NodeVector max_pool(const Node& node)
} {
return convpool::make_ng_pool<ngraph::op::MaxPool>(node);
}
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -26,17 +26,21 @@ namespace ngraph ...@@ -26,17 +26,21 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/** namespace set_1
* @brief Convert ONNX MaxPool operation to an nGraph node. {
* /**
* @param node The ONNX node object representing this operation. * @brief Convert ONNX MaxPool operation to an nGraph node.
* *
* @return The vector containing Ngraph nodes producing output of ONNX MaxPool * @param node The ONNX node object representing this operation.
* operation. *
*/ * @return The vector containing Ngraph nodes producing output of ONNX MaxPool
NodeVector max_pool(const Node& node); * operation.
*/
} // namespace op NodeVector max_pool(const Node& node);
} // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
......
...@@ -27,21 +27,25 @@ namespace ngraph ...@@ -27,21 +27,25 @@ namespace ngraph
{ {
namespace op namespace op
{ {
NodeVector mean(const Node& node) namespace set_1
{ {
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front(); NodeVector mean(const Node& node)
auto shape = sum->get_shape(); {
auto sum = variadic::make_ng_variadic_op<ngraph::op::Add>(node).front();
auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as sum // Create a Constant representing the number of inputs with the same shape as sum
auto count = ngraph::op::Constant::create( auto count = ngraph::op::Constant::create(
sum->get_element_type(), sum->get_element_type(),
shape, shape,
std::vector<int>(shape_size(shape), node.get_ng_inputs().size())); std::vector<int>(shape_size(shape), node.get_ng_inputs().size()));
return {sum / count}; return {sum / count};
} }
} // namespace op } // namespace set_1
} //namespace op
} // namespace onnx_import } // namespace onnx_import
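Note: Mean reuses the variadic Add builder and then divides by a constant tensor filled with the input count N, broadcast to the sum's shape so the division stays elementwise:

    \operatorname{mean}(x_1, \dots, x_N) = \frac{1}{N} \sum_{i=1}^{N} x_i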
...
@@ -26,9 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector mean(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,12 +28,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector min(const Node& node)
                {
                    return variadic::make_ng_variadic_op<ngraph::op::Minimum>(node);
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,19 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector mul(const Node& node)
                {
                    NodeVector ng_inputs{
                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
                    return {
                        std::make_shared<ngraph::op::Multiply>(ng_inputs.at(0), ng_inputs.at(1))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -27,8 +27,12 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector neg(const Node& node) { return {-node.get_ng_inputs().at(0)}; }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,12 +28,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector logical_not(const Node& node)
                {
                    return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector logical_or(const Node& node)
                {
                    NodeVector ng_inputs{
                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
                    return {std::make_shared<ngraph::op::Or>(ng_inputs.at(0), ng_inputs.at(1))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,14 +28,18 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector pow(const Node& node)
                {
                    NodeVector ng_inputs{
                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
                    return {std::make_shared<ngraph::op::Power>(ng_inputs.at(0), ng_inputs.at(1))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -37,31 +37,35 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector prelu(const Node& node)
                {
                    NodeVector ng_inputs{node.get_ng_inputs()};
                    auto data = ng_inputs.at(0);
                    auto data_shape = data->get_shape();

                    std::shared_ptr<ngraph::Node> slope = ng_inputs.at(1);
                    auto slope_shape = slope->get_shape();

                    if ((slope_shape.size() == 1) && (slope_shape.at(0) != 1))
                    {
                        auto it = std::find(
                            std::begin(data_shape), std::end(data_shape), slope_shape.at(0));
                        auto index = std::distance(std::begin(data_shape), it);
                        slope = make_broadcast_node(slope, data->get_shape(), index);
                    }
                    else
                    {
                        auto params = numpy_style_broadcast_for_binary_operation(slope, data);
                        slope = params.at(0);
                    }
                    return {std::make_shared<ngraph::op::Maximum>(data * slope, data)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
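Note: the first branch above broadcasts a rank-1 slope along the first data axis whose extent matches its length (channel-wise); everything else falls back to numpy-style broadcasting. The final max(s*x, x) agrees with ONNX PRelu, x for x >= 0 and s*x otherwise, provided the slope values lie in [0, 1]:

    \max(s x,\, x) = \begin{cases} x, & x \ge 0 \\ s x, & x < 0 \end{cases}
    \qquad (0 \le s \le 1)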
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector prelu(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -31,18 +31,22 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector reciprocal(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);

                    std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
                        data->get_element_type(), Shape{}, std::vector<double>{1});
                    one_node = make_broadcast_node(one_node, data->get_shape());
                    return {one_node / data};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector reciprocal(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -32,29 +32,35 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector reduce_mean(const Node& node)
                {
                    auto input_shape = node.get_ng_inputs().at(0)->get_shape();
                    auto reduction_axes = reduction::detail::get_reduction_axes(node);
                    std::size_t elem_count_product =
                        std::accumulate(std::begin(reduction_axes),
                                        std::end(reduction_axes),
                                        1UL,
                                        [&input_shape](const std::size_t& a, const std::size_t& b) {
                                            return a * input_shape.at(b);
                                        });
                    auto sum_node = reduction::make_ng_reduction_op<ngraph::op::Sum>(
                        node, node.get_ng_inputs().at(0));

                    auto const_node = std::make_shared<ngraph::op::Constant>(
                        sum_node->get_element_type(),
                        Shape{},
                        std::vector<std::size_t>{elem_count_product});

                    auto broadcasted_const_node =
                        make_broadcast_node(const_node, sum_node->get_shape());

                    return {std::make_shared<ngraph::op::Divide>(sum_node, broadcasted_const_node)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
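Note: elem_count_product is the number of elements folded away by the reduction, so dividing the Sum by it yields the mean; for reduction axes A over input extents d:

    \operatorname{ReduceMean}(X, A) = \frac{1}{\prod_{a \in A} d_a} \sum_{\text{axes } A} X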
...
@@ -27,13 +27,17 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector relu(const Node& node)
                {
                    NodeVector ng_inputs{node.get_ng_inputs()};
                    return {std::make_shared<ngraph::op::Relu>(ng_inputs.at(0))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -33,39 +33,45 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector reshape(const Node& node)
                {
                    NodeVector ng_inputs{node.get_ng_inputs()};
                    auto data = ng_inputs.at(0);
                    auto data_shape = data->get_shape();

                    auto output_shape =
                        node.get_attribute_value<std::vector<std::size_t>>("shape", {});

                    // No "shape" attribute (opset >= 5): take the shape from the second input.
                    if (output_shape.empty() && ng_inputs.size() == 2)
                    {
                        // Currently only a Constant shape input is supported.
                        ASSERT_IS_SUPPORTED(node, ng_inputs.at(1)->description() == "Constant")
                            << "doesn't support shape input of other type than Constant.";

                        auto output_shape_node =
                            std::dynamic_pointer_cast<ngraph::op::Constant>(ng_inputs.at(1));
                        output_shape = output_shape_node->get_vector<std::size_t>();
                    }
                    // Do nothing if there is neither a shape attribute nor a second input.
                    else if (output_shape.empty())
                    {
                        return {data};
                    }

                    output_shape =
                        reshape::infer_dimensions(node.get_name(), data_shape, output_shape);
                    return {std::make_shared<ngraph::op::Reshape>(
                        data,
                        reshape::get_default_axis_vector(data_shape.size()),
                        Shape{output_shape})};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
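A worked example of the attribute path above, assuming reshape::infer_dimensions passes a fully specified shape through unchanged and that get_default_axis_vector(3) yields {0, 1, 2}: an input of shape {2, 3, 4} with shape = {6, 4} produces Reshape(data, AxisVector{0, 1, 2}, Shape{6, 4}).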
...
@@ -26,16 +26,20 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                ///
                /// \brief Reshape the input tensor similar to numpy.reshape.
                ///
                /// \param[in] node The ONNX node representing this operation.
                ///
                /// \return Ngraph node representing this operation.
                ///
                NodeVector reshape(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -41,32 +41,41 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector selu(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);
                    double alpha =
                        node.get_attribute_value<double>("alpha", 1.67326319217681884765625);
                    double gamma =
                        node.get_attribute_value<double>("gamma", 1.05070102214813232421875);

                    std::shared_ptr<ngraph::Node> alpha_node =
                        std::make_shared<ngraph::op::Constant>(
                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
                    alpha_node = make_broadcast_node(alpha_node, data->get_shape());

                    std::shared_ptr<ngraph::Node> gamma_node =
                        std::make_shared<ngraph::op::Constant>(
                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{gamma});
                    gamma_node = make_broadcast_node(gamma_node, data->get_shape());

                    std::shared_ptr<ngraph::Node> zero_node =
                        std::make_shared<ngraph::op::Constant>(
                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{0});
                    zero_node = make_broadcast_node(zero_node, data->get_shape());

                    return {gamma_node * (std::make_shared<ngraph::op::Maximum>(data, zero_node) +
                                          alpha_node * std::make_shared<ngraph::op::Exp>(
                                                           std::make_shared<ngraph::op::Minimum>(
                                                               data, zero_node)) -
                                          alpha_node)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
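Expanding the expression above confirms it matches the standard SELU definition, since e^{min(x,0)} - 1 vanishes for x >= 0:

    \mathrm{Selu}(x) = \gamma\left(\max(x, 0) + \alpha e^{\min(x, 0)} - \alpha\right) =
        \begin{cases} \gamma x, & x \ge 0 \\ \gamma \alpha (e^x - 1), & x < 0 \end{cases}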
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector selu(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -30,16 +30,20 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector shape(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);
                    auto data_shape = data->get_shape();

                    return {std::make_shared<ngraph::op::Constant>(
                        ngraph::element::i64, Shape{data_shape.size()}, data_shape)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector shape(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector sigmoid(const Node& node)
                {
                    return {std::make_shared<ngraph::op::Sigmoid>(node.get_ng_inputs().at(0))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -35,32 +35,37 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector slice(const Node& node)
                {
                    std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);
                    Shape data_shape = data->get_shape();

                    auto starts = node.get_attribute_value<std::vector<int64_t>>("starts");
                    auto ends = node.get_attribute_value<std::vector<int64_t>>("ends");

                    auto axes = node.get_attribute_value<std::vector<int64_t>>(
                        "axes", common::get_monotonic_range<int64_t>(data_shape.size()));

                    Shape lower_bounds(data_shape.size());
                    Shape upper_bounds = data_shape;

                    for (auto idx = 0; idx < axes.size(); ++idx)
                    {
                        size_t axis = axes.at(idx);
                        lower_bounds.at(axis) =
                            get_valid_array_idx(starts.at(idx), data_shape.at(axis));
                        upper_bounds.at(axis) =
                            get_valid_array_idx(ends.at(idx), data_shape.at(axis));
                    }
                    return {std::make_shared<ngraph::op::Slice>(data, lower_bounds, upper_bounds)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector slice(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -27,31 +27,35 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softmax(const Node& node)
                {
                    NodeVector inputs{node.get_ng_inputs()};
                    auto data = inputs.at(0);
                    auto data_shape = data->get_shape();

                    int axis = node.get_attribute_value<int64_t>("axis", 1);

                    if (axis < 0)
                    {
                        axis = data_shape.size() + axis;
                    }

                    ASSERT_VALID_ARGUMENT(node, axis < data_shape.size())
                        << "provided 'axis' value:" << axis
                        << " is out of input tensor dimensions range.";

                    // Reduce over every dimension from `axis` to the last one.
                    std::vector<size_t> axes(data_shape.size() - axis);
                    std::iota(std::begin(axes), std::end(axes), axis);
                    return {std::make_shared<ngraph::op::Softmax>(data, axes)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
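Note: the iota fills axes with {axis, ..., rank - 1}, so the nGraph Softmax normalizes jointly over every trailing dimension; that mirrors ONNX Softmax (opset 1), which flattens the input to 2-D at `axis` and normalizes over the second half. For example:

    input shape {2, 3, 4}, axis = 1  =>  reduction axes = {1, 2}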
...
@@ -26,9 +26,14 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softmax(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
...
@@ -33,19 +33,23 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softplus(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);

                    std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
                        data->get_element_type(), Shape{}, std::vector<double>{1});
                    one_node = make_broadcast_node(one_node, data->get_shape());

                    return {std::make_shared<ngraph::op::Log>(
                        std::make_shared<ngraph::op::Exp>(data) + one_node)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softplus(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -33,18 +33,22 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softsign(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);

                    std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
                        data->get_element_type(), Shape{}, std::vector<double>{1});
                    one_node = make_broadcast_node(one_node, data->get_shape());

                    return {data / (std::make_shared<ngraph::op::Abs>(data) + one_node)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector softsign(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -80,79 +80,84 @@ namespace ngraph
        namespace op
        {
            namespace set_1
            {
                namespace detail
                {
                    template <typename T>
                    inline T get_valid_array_index(T left, T right)
                    {
                        return (left >= 0) ? std::min(left, right)
                                           : std::max(static_cast<T>(0), right + left);
                    }

                    inline std::shared_ptr<ngraph::op::Slice>
                        make_ng_slice(const std::shared_ptr<ngraph::Node>& node,
                                      std::vector<std::size_t> axes,
                                      std::vector<std::size_t> starts,
                                      std::vector<std::size_t> ends)
                    {
                        std::vector<std::size_t> upper_bounds{node->get_shape()};
                        std::vector<std::size_t> lower_bounds(upper_bounds.size());
                        for (std::size_t index{0}; index < axes.size(); ++index)
                        {
                            std::size_t axis{axes.at(index)};
                            lower_bounds.at(axis) =
                                get_valid_array_index(starts.at(index), node->get_shape().at(axis));
                            upper_bounds.at(axis) =
                                get_valid_array_index(ends.at(index), node->get_shape().at(axis));
                        }
                        return std::make_shared<ngraph::op::Slice>(
                            node, lower_bounds, upper_bounds);
                    }
                } // namespace detail

                NodeVector split(const Node& node)
                {
                    std::shared_ptr<ngraph::Node> input = node.get_ng_inputs().at(0);
                    std::size_t count_outputs{node.get_output_names().size()};
                    int64_t axis{node.get_attribute_value<int64_t>("axis", 0)};
                    std::size_t axis_to_split{static_cast<std::size_t>(axis)};
                    if (axis < 0)
                    {
                        axis_to_split = input->get_shape().size() + axis;
                    }
                    else if (axis_to_split >= input->get_shape().size())
                    {
                        throw error::op::split::OutOfRange{node.get_name()};
                    }
                    std::size_t length_axis_to_split{input->get_shape().at(axis_to_split)};
                    std::vector<std::size_t> length_parts;
                    try
                    {
                        length_parts = node.get_attribute_value<std::vector<std::size_t>>("split");
                    }
                    catch (const std::exception&)
                    {
                        if (length_axis_to_split % count_outputs)
                        {
                            throw error::op::split::Parts{
                                node.get_name(), count_outputs, length_axis_to_split};
                        }
                        length_parts.assign(count_outputs, length_axis_to_split / count_outputs);
                    }

                    std::size_t start_index{0};
                    NodeVector outputs;
                    for (const auto& length_part : length_parts)
                    {
                        std::size_t end_index{start_index + length_part};
                        outputs.push_back(detail::make_ng_slice(
                            input, {axis_to_split}, {start_index}, {end_index}));
                        start_index = end_index;
                    }
                    return outputs;
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
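A standalone sketch of the bounds arithmetic performed by split and detail::make_ng_slice above, under the even-split fallback; the shape, axis, and part count below are made-up example values, not taken from the commit:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> shape{6, 4}; // example input shape
        std::size_t axis = 0;                 // axis being split
        std::size_t parts = 3;                // number of requested outputs

        // Even split: each output covers shape[axis] / parts elements along the axis.
        std::size_t step = shape.at(axis) / parts;
        for (std::size_t start = 0; start < shape.at(axis); start += step)
        {
            // Each output becomes Slice(input, lower, upper) with these bounds.
            std::cout << "slice on axis " << axis << ": [" << start << ", "
                      << start + step << ")\n";
        }
        return 0;
    }

For shape {6, 4} split into 3 parts on axis 0 this prints the half-open ranges [0, 2), [2, 4), [4, 6), exactly the lower/upper bounds the importer passes to ngraph::op::Slice.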
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector split(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector sqrt(const Node& node)
                {
                    return {std::make_shared<ngraph::op::Sqrt>(node.get_ng_inputs().at(0))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -30,38 +30,42 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector squeeze(const Node& node)
                {
                    NodeVector inputs{node.get_ng_inputs()};
                    auto data = inputs.at(0);
                    auto data_shape = data->get_shape();
                    auto axes = node.get_attribute_value<std::vector<uint64_t>>("axes", {});

                    if (axes.empty())
                    {
                        for (auto index = 0; index < data_shape.size(); ++index)
                        {
                            if (data_shape.at(index) == 1)
                            {
                                axes.push_back(index);
                            }
                        }
                    }

                    std::sort(std::begin(axes), std::end(axes), std::greater<uint64_t>());
                    AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};
                    for (auto axis : axes)
                    {
                        data_shape.erase(std::next(std::begin(data_shape), axis));
                    }
                    return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
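Note: sorting the axes in descending order is what keeps the erase calls valid; removing the highest axis first leaves the lower, not-yet-processed indices untouched. E.g. shape {1, 3, 1} with axes {0, 2}: erasing axis 2 gives {1, 3}, then axis 0 gives {3}. An ascending order would erase the wrong element on the second step.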
...
@@ -26,9 +26,14 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector squeeze(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
...
@@ -28,14 +28,19 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector sub(const Node& node)
                {
                    NodeVector ng_inputs{
                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
                    return {
                        std::make_shared<ngraph::op::Subtract>(ng_inputs.at(0), ng_inputs.at(1))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,12 +28,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector sum(const Node& node)
                {
                    return variadic::make_ng_variadic_op<ngraph::op::Add>(node);
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -29,12 +29,16 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector tanh(const Node& node)
                {
                    return {std::make_shared<ngraph::op::Tanh>(node.get_ng_inputs().at(0))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -36,22 +36,27 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector thresholded_relu(const Node& node)
                {
                    auto data = node.get_ng_inputs().at(0);
                    double alpha = node.get_attribute_value<double>("alpha", 1.0);

                    std::shared_ptr<ngraph::Node> alpha_node =
                        std::make_shared<ngraph::op::Constant>(
                            data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
                    alpha_node = make_broadcast_node(alpha_node, data->get_shape());

                    auto data_map = std::make_shared<ngraph::op::Convert>(
                        std::make_shared<ngraph::op::Greater>(data, alpha_node),
                        data->get_element_type());
                    return {data * data_map};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
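Note: Greater produces a boolean mask that Convert casts back to the data's element type, so the final multiply realizes a 0/1 gate:

    y = x \cdot \mathbb{1}[x > \alpha]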
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector thresholded_relu(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,17 +28,22 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector transpose(const Node& node)
                {
                    std::shared_ptr<ngraph::Node> data = node.get_ng_inputs().at(0);

                    auto permute_axes =
                        node.get_attribute_value<std::vector<std::size_t>>("perm", {});

                    return {(permute_axes.empty()) ? reshape::transpose(data)
                                                   : reshape::reorder_axes(data, permute_axes)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -26,8 +26,13 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector transpose(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
...
@@ -28,32 +28,36 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector unsqueeze(const Node& node)
                {
                    NodeVector inputs{node.get_ng_inputs()};
                    auto data = inputs.at(0);
                    auto data_shape = data->get_shape();
                    auto axes = node.get_attribute_value<std::vector<int64_t>>("axes");

                    ASSERT_VALID_ARGUMENT(node, !axes.empty()) << "'axes' attribute is mandatory.";

                    std::sort(std::begin(axes), std::end(axes), std::less<int64_t>());
                    AxisVector input_order{reshape::get_default_axis_vector(data_shape.size())};

                    for (auto axis : axes)
                    {
                        ASSERT_VALID_ARGUMENT(node, axis >= 0 && axis <= data_shape.size())
                            << "provided 'axes' attribute is not valid.";
                        data_shape.insert(std::next(std::begin(data_shape), axis), 1);
                    }
                    return {std::make_shared<ngraph::op::Reshape>(data, input_order, data_shape)};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
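Note: here the ascending sort is the mirror image of squeeze's descending one. Each inserted unit axis lands at its final position because every previously inserted axis index is smaller. E.g. shape {3, 4} with axes {0, 3}: inserting at 0 gives {1, 3, 4}, then at 3 gives {1, 3, 4, 1}.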
...
@@ -26,9 +26,14 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                NodeVector unsqueeze(const Node& node);
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
} // namespace ngraph
...
@@ -30,20 +30,24 @@ namespace ngraph
    {
        namespace op
        {
            namespace set_1
            {
                inline NodeVector logical_xor(const Node& node)
                {
                    NodeVector ng_inputs{
                        numpy_style_broadcast_for_binary_operation(node.get_ng_inputs())};
                    auto left = ng_inputs.at(0);
                    auto not_left = std::make_shared<ngraph::op::Not>(left);
                    auto right = ng_inputs.at(1);
                    auto not_right = std::make_shared<ngraph::op::Not>(right);
                    return {std::make_shared<ngraph::op::Or>(
                        std::make_shared<ngraph::op::And>(left, not_right),
                        std::make_shared<ngraph::op::And>(not_left, right))};
                }
            } // namespace set_1
        } // namespace op
    } // namespace onnx_import
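Note: logical_xor is composed from And/Or/Not via the standard decomposition:

    x \oplus y = (x \land \lnot y) \lor (\lnot x \land y)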