Commit e0c3400b authored by Jaikrishnan Menon

Merge branch 'master' into mkldnn-compile

parents fc5cda7c fbf5a0cf
@@ -26,20 +26,24 @@ endif()
 project (ngraph)

-SET( GCC_MIN_VERSION 4.8)
-SET( CLANG_MIN_VERSION 3.8)
+SET(GCC_MIN_VERSION 4.8)
+SET(CLANG_MIN_VERSION 3.8)
+SET(APPLE_CLANG_MIN_VERSION 9.0)

 if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
     if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS GCC_MIN_VERSION)
         message(FATAL_ERROR "GCC version must be at least ${GCC_MIN_VERSION}!")
     endif()
 elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang")
-    # require at least clang 3.8
    if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS CLANG_MIN_VERSION)
        message(FATAL_ERROR "Clang version must be at least ${CLANG_MIN_VERSION}!")
    endif()
+elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
+    if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS APPLE_CLANG_MIN_VERSION)
+        message(FATAL_ERROR "Apple Clang version must be at least ${APPLE_CLANG_MIN_VERSION}!")
+    endif()
 else()
-    message(WARNING "You are using an unsupported compiler! Compilation has only been tested with Clang ( ${CLANG_MIN_VERSION} and up) and GCC ( ${GCC_MIN_VERSION} and up). ")
+    message(WARNING "You are using an unsupported compiler. Compilation has only been tested with Clang (${CLANG_MIN_VERSION} and up), Apple Clang (${APPLE_CLANG_MIN_VERSION} and up), and GCC (${GCC_MIN_VERSION} and up).")
 endif()

 if($ENV{NGRAPH_USE_PREBUILT_LLVM})
@@ -133,19 +137,29 @@ elseif(NGRAPH_GPU_ENABLE)
     message(FATAL_ERROR "GPU was required but CUDA library was not found")
 endif()

+#-----------------------------------------------------------------------------------------------
+# distributed support
+#-----------------------------------------------------------------------------------------------
+if(NGRAPH_DISTRIBUTED_ENABLE)
+    find_package(MPI REQUIRED)
+    if(MPI_CXX_FOUND)
+        add_definitions(-DNGRAPH_DISTRIBUTED)
+    endif()
+endif()
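Note: the `NGRAPH_DISTRIBUTED` define added here is what gates the MPI code paths introduced elsewhere in this merge (the `AllReduce` op and the MPI link flags in `src/CMakeLists.txt`). As a hedged sketch of the pattern this enables, code compiled under the define can call MPI directly; the helper below is illustrative only, not part of this change:

    // Illustrative only: an elementwise sum across MPI ranks, in place.
    // all_reduce_f32 is a hypothetical helper, not nGraph API.
    #ifdef NGRAPH_DISTRIBUTED
    #include <mpi.h>

    void all_reduce_f32(float* data, int count)
    {
        MPI_Allreduce(MPI_IN_PLACE, data, count, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
    }
    #endif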
 #-----------------------------------------------------------------------------------------------
 # External projects install directory
 #-----------------------------------------------------------------------------------------------
 # Root for external installs, when using `make`, e.g.
-# ${EXTERNAL_INSTALL_DIR}
-# ├── eigen    <-- ${EIGEN_INSTALL_DIR}
+# ${EXTERNAL_PROJECTS_ROOT}
+# ├── eigen
 # │   ├── include
 # │   │   └── eigen3    <-- ${EIGEN_INCLUDE_DIR}
 # │   │       ├── Eigen
 # │   ...
 # │
-# └── mkldnn    <-- ${MKLDNN_INSTALL_DIR}
+# └── mkldnn
 #     ├── include    <-- ${MKLDNN_INCLUDE_DIR}
 #     │   ├── mkldnn.h
 #     │   ...
-# Intel® nGraph™ library project
+# Intel® nGraph™ library

-Welcome to the Intel nGraph project, an open source C++ library for developers
-of Deep Learning (DL) systems. Here you will find a suite of components, APIs,
-and documentation that can be used to compile and run Deep Neural Network (DNN)
-models defined in a variety of frameworks.
+Welcome to Intel nGraph, an open source C++ library for developers of Deep
+Learning (DL) systems. Here you will find a suite of components, APIs, and
+documentation that can be used to compile and run Deep Neural Network (DNN)
+models defined in a variety of frameworks.

 The nGraph library translates a framework’s representation of computations into
@@ -14,7 +14,8 @@ and data layout abstraction.
 See our [install] docs for how to get started.

-For this early release, we provide framework integration guides to compile
+For this early release, we provide [framework integration guides] to compile
 MXNet and TensorFlow-based projects.

 [install]: http://ngraph.nervanasys.com/docs/cpp/installation.html
\ No newline at end of file
+[framework integration guides]: http://ngraph.nervanasys.com/docs/cpp/framework-integration-guides.html
 # API Changes

+## Changes to ops
+
+* The namespace `ngraph::op` is only for actual ops. Helpers have been moved into
+  `ngraph::op::util`:
+  + `BinaryElementwiseArithmetic`
+  + `BinaryElementwiseComparison`
+  + `BinaryElementwise`
+  + `RequiresTensorViewArgs`
+  + `UnaryElementwiseArithmetic`
+  + `UnaryElementwise`
+
+  Ops defined outside of nGraph core will need to get the base class from `ngraph::op::util` and
+  change the include file to `#include "ngraph/ops/util/requires_tensor_view_args.hpp"`, etc.
+  See any of the core ops for an example, or the sketch just below.
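A hedged sketch of the migration for an out-of-tree op (`MyOp` is hypothetical; the base-class constructor signature is assumed from the core ops in this merge):

    // Before: #include "ngraph/ops/op.hpp" and derive from
    // ngraph::op::UnaryElementwiseArithmetic.
    // After:
    #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"

    class MyOp : public ngraph::op::util::UnaryElementwiseArithmetic
    {
    public:
        // Constructor form assumed from core ops such as Abs.
        MyOp(const std::shared_ptr<ngraph::Node>& arg)
            : UnaryElementwiseArithmetic("MyOp", arg)
        {
        }
    };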
 ## Changes to convolution and pooling ops

 * Backprop ops have been added for convolution ops.
@@ -17,8 +17,6 @@
 # Enable ExternalProject CMake module
 include(ExternalProject)

-set(EIGEN_INSTALL_DIR ${EXTERNAL_INSTALL_DIR}/eigen)
-set(EIGEN_PROJECT eigen)
 set(EIGEN_GIT_TAG d608d9f3f577118981acbdd40da9dcf6b514668a)
 set(EIGEN_GIT_URL https://github.com/jmenon/eigen)
@@ -29,25 +27,35 @@ set(EIGEN_GIT_URL https://github.com/jmenon/eigen)
 # The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2.
 if (${CMAKE_VERSION} VERSION_LESS 3.2)
     ExternalProject_Add(
-        ${EIGEN_PROJECT}
+        ext_eigen
         GIT_REPOSITORY ${EIGEN_GIT_URL}
         GIT_TAG ${EIGEN_GIT_TAG}
         UPDATE_COMMAND ""
-        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EIGEN_INSTALL_DIR} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EXTERNAL_PROJECTS_ROOT}/eigen -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+        TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/tmp"
+        STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/stamp"
+        DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/download"
+        SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/src"
+        BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/build"
+        INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen"
         )
 else()
     ExternalProject_Add(
-        ${EIGEN_PROJECT}
+        ext_eigen
         GIT_REPOSITORY ${EIGEN_GIT_URL}
         GIT_TAG ${EIGEN_GIT_TAG}
         UPDATE_COMMAND ""
-        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EIGEN_INSTALL_DIR} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-        BUILD_BYPRODUCTS "${EIGEN_INSTALL_DIR}/include/eigen3"
+        CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${EXTERNAL_PROJECTS_ROOT}/eigen -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+        TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/tmp"
+        STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/stamp"
+        DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/download"
+        SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/src"
+        BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/build"
+        INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen"
+        BUILD_BYPRODUCTS "${EXTERNAL_PROJECTS_ROOT}/eigen/include/eigen3"
        )
 endif()

 #----------------------------------------------------------------------------------------------------------
-ExternalProject_Get_Property(eigen source_dir binary_dir)
-set(EIGEN_INCLUDE_DIR "${EIGEN_INSTALL_DIR}/include/eigen3" PARENT_SCOPE)
+set(EIGEN_INCLUDE_DIR "${EXTERNAL_PROJECTS_ROOT}/eigen/include/eigen3" PARENT_SCOPE)
@@ -27,46 +27,53 @@ SET(GTEST_GIT_LABEL release-1.8.0)
 # The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2.
 if (${CMAKE_VERSION} VERSION_LESS 3.2)
     ExternalProject_Add(
-        gtest
+        ext_gtest
         GIT_REPOSITORY ${GTEST_GIT_REPO_URL}
         GIT_TAG ${GTEST_GIT_LABEL}
-        PREFIX ${CMAKE_CURRENT_BINARY_DIR}/gtest
         # Disable install step
         INSTALL_COMMAND ""
         UPDATE_COMMAND ""
         CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+        TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/tmp"
+        STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/stamp"
+        DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/download"
+        SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/src"
+        BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/build"
+        INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest"
        )
 else()
     ExternalProject_Add(
-        gtest
+        ext_gtest
         GIT_REPOSITORY ${GTEST_GIT_REPO_URL}
         GIT_TAG ${GTEST_GIT_LABEL}
-        PREFIX ${CMAKE_CURRENT_BINARY_DIR}/gtest
         # Disable install step
         INSTALL_COMMAND ""
         UPDATE_COMMAND ""
         CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-        BUILD_BYPRODUCTS "${CMAKE_CURRENT_BINARY_DIR}/gtest/src/gtest-build/googlemock/gtest/libgtest.a"
+        TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/tmp"
+        STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/stamp"
+        DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/download"
+        SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/src"
+        BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest/build"
+        INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/gtest"
+        BUILD_BYPRODUCTS "${EXTERNAL_PROJECTS_ROOT}/gtest/build/googlemock/gtest/libgtest.a"
        )
 endif()

 #----------------------------------------------------------------------------------------------------------
-# Get GTest source and binary directories from CMake project
-ExternalProject_Get_Property(gtest source_dir binary_dir)
 get_filename_component(
     GTEST_INCLUDE_DIR
-    "${CMAKE_CURRENT_BINARY_DIR}/gtest/src/gtest/googletest/include"
+    "${EXTERNAL_PROJECTS_ROOT}/gtest/src/googletest/include"
     ABSOLUTE)
 set(GTEST_INCLUDE_DIR "${GTEST_INCLUDE_DIR}" PARENT_SCOPE)

 # Create a libgtest target to be used as a dependency by test programs
 add_library(libgtest IMPORTED STATIC GLOBAL)
-add_dependencies(libgtest gtest)
+add_dependencies(libgtest ext_gtest)

 # Set libgtest properties
 set_target_properties(libgtest PROPERTIES
-    "IMPORTED_LOCATION" "${binary_dir}/googlemock/gtest/libgtest.a"
+    "IMPORTED_LOCATION" "${EXTERNAL_PROJECTS_ROOT}/gtest/build/googlemock/gtest/libgtest.a"
     "IMPORTED_LINK_INTERFACE_LIBRARIES" "${CMAKE_THREAD_LIBS_INIT}"
 )
@@ -20,11 +20,10 @@ include(ExternalProject)
 # Fetch and install MKL-DNN
 #----------------------------------------------------------------------------------------------------------
-if(NGRAPH_CPU_ENABLE AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+if(NGRAPH_CPU_ENABLE)
     set(MKLDNN_GIT_REPO_URL https://github.com/intel/mkl-dnn)
     set(MKLDNN_GIT_TAG "3e1f8f5")
-    set(MKLDNN_INSTALL_DIR ${EXTERNAL_INSTALL_DIR}/mkldnn)

     # The 'BUILD_BYPRODUCTS' argument was introduced in CMake 3.2.
     if(${CMAKE_VERSION} VERSION_LESS 3.2)
@@ -38,7 +37,13 @@ if(NGRAPH_CPU_ENABLE AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
             CMAKE_ARGS
                 -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                 -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-                -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
+                -DCMAKE_INSTALL_PREFIX=${EXTERNAL_PROJECTS_ROOT}/mkldnn
+            TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/tmp"
+            STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/stamp"
+            DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/download"
+            SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/src"
+            BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/build"
+            INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn"
            )
     else()
         ExternalProject_Add(
@@ -51,8 +56,14 @@ if(NGRAPH_CPU_ENABLE AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
             CMAKE_ARGS
                 -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
                 -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-                -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
-            BUILD_BYPRODUCTS "${MKLDNN_INSTALL_DIR}/include/mkldnn.hpp"
+                -DCMAKE_INSTALL_PREFIX=${EXTERNAL_PROJECTS_ROOT}/mkldnn
+            TMP_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/tmp"
+            STAMP_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/stamp"
+            DOWNLOAD_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/download"
+            SOURCE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/src"
+            BINARY_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/build"
+            INSTALL_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn"
+            BUILD_BYPRODUCTS "${EXTERNAL_PROJECTS_ROOT}/mkldnn/include/mkldnn.hpp"
            )
     endif()
@@ -67,7 +78,7 @@ if(NGRAPH_CPU_ENABLE AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
        )

-    set(MKLDNN_INCLUDE_DIR "${MKLDNN_INSTALL_DIR}/include" PARENT_SCOPE)
-    set(MKLDNN_LIB_DIR "${MKLDNN_INSTALL_DIR}/lib" PARENT_SCOPE)
+    set(MKLDNN_INCLUDE_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/include" PARENT_SCOPE)
+    set(MKLDNN_LIB_DIR "${EXTERNAL_PROJECTS_ROOT}/mkldnn/lib" PARENT_SCOPE)
 endif()
@@ -22,13 +22,11 @@ if(NGRAPH_TBB_ENABLE)
     set(TBB_GIT_REPO_URL https://github.com/01org/tbb)
     set(TBB_GIT_TAG "tbb_2018")

-    if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
-        configure_file(${CMAKE_SOURCE_DIR}/cmake/tbb_fetch.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/tbb/CMakeLists.txt)
-        execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
-                        WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tbb")
-        execute_process(COMMAND "${CMAKE_COMMAND}" --build .
-                        WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tbb")
-        set(TBB_ROOT ${CMAKE_CURRENT_BINARY_DIR}/tbb/tbb-src PARENT_SCOPE)
-    endif()
+    configure_file(${CMAKE_SOURCE_DIR}/cmake/tbb_fetch.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/tbb/CMakeLists.txt)
+    execute_process(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" .
+                    WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tbb")
+    execute_process(COMMAND "${CMAKE_COMMAND}" --build .
+                    WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tbb")
+    set(TBB_ROOT ${CMAKE_CURRENT_BINARY_DIR}/tbb/tbb-src PARENT_SCOPE)
 endif()
@@ -79,19 +79,9 @@ build_docker_image: expand_dockerfile_templates
 build_docker: build_docker_image

 # Build docs
-docs: sphinx_doc doxygen_doc
+docs: sphinx_doc

-doxygen_doc:
-	# doxygen html docs build
-	docker run --rm --tty \
-	${VOLUME} \
-	${DOCKER_RUN_ENV} \
-	--env RUN_UID="$(shell id -u)" \
-	--env RUN_CMD="set -e ; set -o pipefail ; cd ${DOCKUSER_HOME}/ngraph-cpp-test/doc/doxygen; env VERBOSE=1 make html 2>&1 | tee make_sphinx_html.log" \
-	"build_ngraph_cpp:${DBUILD_VERSION}" \
-	sh -c "${DOCKUSER_HOME}/ngraph-cpp-test/contrib/docker/run_as_user.sh"
-
-sphinx_doc:
+sphinx_doc: build_docker_image
 	# sphinx html docs build
 	docker run --rm --tty \
 	${VOLUME} \
@@ -19,12 +19,6 @@ help:
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

-doxy-code:
-	$(Q)(cat ngraph.doxyfile ; echo "STRIP_FROM_PATH=${NGRAPH_BASE}" ) | doxygen - 2>&1 | tee doc.log
-
-doxy: doxy-code
-
 clean:
 	@rm -rf $(BUILDDIR)/*
 	@rm -rf html
@@ -32,7 +26,10 @@ clean:
 	@rm -rf doxygen
 	@rm -rf latex

-htmldocs: doxy html
+doxy-code:
+	$(Q)(cat ngraph.doxyfile ; echo "STRIP_FROM_PATH=${NGRAPH_BASE}" ) | doxygen - 2>&1 | tee doc.log
+
+html: doxy-code

 pickle:
 	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@@ -41,8 +38,6 @@ pickle:
 json: prep
 	$(SPHINXBUILD) -t $(DOC_TAG) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
-	@rm -rf samples
-	@rm -rf boards
 	@echo
 	@echo "Build finished; now you can process the JSON files."
@@ -17,10 +17,11 @@
 Intel nGraph library project
 #############################

-Welcome to the Intel nGraph project, an open source C++ library for developers
-of :abbr:`Deep Learning (DL)` (DL) systems. Here you will find a suite of
-components, APIs, and documentation that can be used to compile and run
-:abbr:`Deep Neural Network (DNN)` (DNN) models defined in a variety of frameworks.
+Welcome to Intel nGraph, an open source C++ library for developers of
+:abbr:`Deep Learning (DL)` (DL) systems. Here you will find a suite
+of components, APIs, and documentation that can be used to compile
+and run :abbr:`Deep Neural Network (DNN)` (DNN) models defined in a
+variety of frameworks.

 .. figure:: graphics/ngraph-hub.png
@@ -7,24 +7,23 @@ Install the Intel® nGraph™ library
 Build Environments
 ==================

-The |release| version of |project| supports Linux\* or UNIX-based
-systems which have recent updates of the following packages and
-prerequisites:
+The |release| version of |project| supports Linux\*-based systems which
+have recent updates of the following packages and prerequisites:

 .. csv-table::
    :header: "Operating System", "Compiler", "Build System", "Status", "Additional Packages"
    :widths: 25, 15, 25, 20, 25
    :escape: ~

-   CentOS 7.4 64-bit, CLang 3.4, GCC 4.8 + CMake 2.8, supported, ``patch diffutils zlib1g-dev libtinfo-dev``
+   CentOS 7.4 64-bit, GCC 4.8, CMake 3.2, supported, ``patch diffutils zlib1g-dev libtinfo-dev``
    Ubuntu 16.04 (LTS) 64-bit, CLang 3.9, CMake 3.5.1 + GNU Make, supported, ``build-essential cmake clang-3.9 git libtinfo-dev``
-   Ubuntu 16.04 (LTS) 64-bit, CLang 4.0, CMake 3.5.1 + GNU Make, officially unsupported, ``build-essential cmake clang-4.0 git libtinfo-dev``
    Clear Linux\* OS for Intel Architecture, CLang 5.0.1, CMake 3.10.2, experimental, bundles ``machine-learning-basic dev-utils python3-basic python-basic-dev``

-On Ubuntu 16.04 with ``gcc-5.4.0`` or ``clang-3.9``, the recommended option
-is to add ``-DNGRAPH_USE_PREBUILT_LLVM=TRUE`` to the :command:`cmake` command.
-This gets a pre-built tarball of LLVM+Clang from `llvm.org`_, and substantially
-reduces build times.
+Other configurations may work, but aren't tested; on Ubuntu 16.04 with
+``gcc-5.4.0`` or ``clang-3.9``, for example, we recommend adding
+``-DNGRAPH_USE_PREBUILT_LLVM=TRUE`` to the :command:`cmake` command in step 4
+below. This gets a pre-built tarball of LLVM+Clang from `llvm.org`_, and will
+substantially reduce build time.

 If using ``gcc-4.8``, it may be necessary to add symlinks from ``gcc`` to
 ``gcc-4.8``, and from ``g++`` to ``g++-4.8``, in your :envvar:`PATH`, even
@@ -33,7 +32,7 @@ flags when building. (You should NOT supply the `-DNGRAPH_USE_PREBUILT_LLVM`
 flag in this case, because the prebuilt tarball supplied on llvm.org is not
 compatible with a gcc-4.8 based build.)

-Support for macOS is limited; see the macOS development prerequisites
+Support for macOS is limited; see the `macOS development prerequisites`_
 section at the end of this page for details.
@@ -95,6 +94,7 @@ information about how to change or customize this location.
    the ``doc/sphinx`` directory to build HTML API docs inside the
    ``/docs/doxygen/`` directory.

+.. macos_development_prerequisites:

 macOS Development Prerequisites
 -------------------------------
.. constant.rst:
########
Constant
########
Description
===========
Literal constant tensor.
The output is a tensor initialized from the ``values`` attribute.
Attributes
----------
+-----------------+------------------------------+---------------------------------------+
| Name | Type | Notes |
+=================+==============================+=======================================+
| ``type`` | ``ngraph::element::type`` | The element type of the value |
| | | in the computation. |
+-----------------+------------------------------+---------------------------------------+
| ``shape`` | ``ngraph::Shape`` | The shape of the constant. |
+-----------------+------------------------------+---------------------------------------+
| ``values`` | ``const std::vector<T>&`` | Constant elements in row-major order. |
| | | T must be compatible with the element |
| | | type. |
+-----------------+------------------------------+---------------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | ``type`` | ``shape`` |
+-----------------+-------------------------+--------------------------------+
C++ Interface
=============
.. doxygenclass:: ngraph::op::Constant
:members:
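A usage sketch (not part of the docs page itself): building a small constant from the attributes above. The templated constructor form is assumed from nGraph core of this era.

    #include "ngraph/ops/constant.hpp"

    // A 2x2 float32 constant; values are supplied in row-major order.
    auto c = std::make_shared<ngraph::op::Constant>(
        ngraph::element::f32,
        ngraph::Shape{2, 2},
        std::vector<float>{1.0f, 2.0f, 3.0f, 4.0f});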
.. convert.rst:
#######
Convert
#######
Description
===========
Convert a tensor from one element type to another.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Attributes
----------
+------------------+---------------------------+---------------------------------+
| Name | Type | Notes |
+==================+===========================+=================================+
| ``element_type`` | ``ngraph::element::type`` | The element type of the result. |
+------------------+---------------------------+---------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | ``element_type`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow \texttt{Convert}(\Delta,\texttt{arg->get_element_type()})
C++ Interface
=============
.. doxygenclass:: ngraph::op::Convert
:members:
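A usage sketch mirroring the tables above (the constructor form is assumed): converting a float32 tensor to float64.

    #include "ngraph/ops/convert.hpp"

    // "arg" is any Node producing a tensor; only the element type changes.
    auto converted = std::make_shared<ngraph::op::Convert>(arg, ngraph::element::f64);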
@@ -125,3 +125,4 @@ C++ Interface

 .. doxygenclass:: ngraph::op::Convolution
-   :members:
\ No newline at end of file
+   :members:
.. cos.rst:
###
Cos
###
Description
===========
Elementwise cosine operation.
Produces a tensor of the same element type and shape as ``arg``,
where the value at each coordinate of ``output`` is the cosine of the
value at the corresponding coordinate of ``arg``.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | Same as ``arg`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
\texttt{output}_{i_0, \ldots, i_{n-1}} = \cos(\texttt{arg}_{i_0, \ldots, i_{n-1}})
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow -\Delta\ \sin(\texttt{arg})
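This follows from the chain rule applied elementwise to the forward definition:

.. math::

   \overline{\texttt{arg}} \leftarrow \Delta\ \frac{\partial \cos(\texttt{arg})}{\partial \texttt{arg}} = -\Delta\ \sin(\texttt{arg})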
C++ Interface
=============
.. doxygenclass:: ngraph::op::Cos
:members:
.. cosh.rst:
####
Cosh
####
Description
===========
Elementwise hyperbolic cosine operation.
Produces a tensor of the same element type and shape as ``arg``, where
the value at each coordinate of ``output`` is the hyperbolic cosine of
the value at the corresponding coordinate of ``arg``.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | Same as ``arg`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
\texttt{output}_{i_0, \ldots, i_{n-1}} = \cosh(\texttt{arg}_{i_0, \ldots, i_{n-1}})
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow \Delta\ \sinh(\texttt{arg})
C++ Interface
=============
.. doxygenclass:: ngraph::op::Cosh
:members:
@@ -58,4 +58,9 @@ Not currently a comprehensive list.
    broadcast.rst
    ceiling.rst
    concatenate.rst
+   constant.rst
+   convert.rst
    convolution.rst
+   cos.rst
+   cosh.rst
@@ -17,6 +17,30 @@ Architecture CPUs, the Intel® Nervana Neural Network Processor™ (NNP),
 and NVIDIA\* GPUs. Currently-supported compiler optimizations include efficient
 memory management and data layout abstraction.

+Why is this needed?
+--------------------
+
+When Deep Learning (DL) frameworks first emerged as the vehicle for training
+and inference models, they were designed around kernels optimized for a
+particular platform. As a result, many backend details were being exposed in
+the model definitions, making the adaptability and portability of DL models
+to other or more advanced backends inherently complex and expensive.
+
+The traditional approach means that an algorithm developer cannot easily adapt
+his or her model to different backends. Making a model run on a different
+framework is also problematic because the user must separate the essence of
+the model from the performance adjustments made for the backend, translate
+to similar ops in the new framework, and finally make the necessary changes
+for the preferred backend configuration on the new framework.
+
+We designed the Intel nGraph project to substantially reduce these kinds of
+engineering complexities. While optimized kernels for deep-learning primitives
+are provided through the project and via libraries like Intel® Math Kernel
+Library (Intel® MKL) for Deep Neural Networks (Intel® MKL-DNN), there are
+several compiler-inspired ways in which performance can be further optimized.
+=======
+
 The *nGraph core* uses a strongly-typed and platform-neutral stateless graph
 representation for computations. Each node, or *op*, in the graph corresponds
 to one step in a computation, where each step produces zero or more tensor
@@ -39,4 +63,7 @@ read more about design decisions and what is tentatively in the pipeline
 for development in our `SysML conference paper`_.

 .. _frontend: http://neon.nervanasys.com/index.html/
-.. _SysML conference paper: https://arxiv.org/pdf/1801.08058.pdf
\ No newline at end of file
+.. _SysML conference paper: https://arxiv.org/pdf/1801.08058.pdf
+.. _MXNet: http://mxnet.incubator.apache.org/
+.. _TensorFlow: https://www.tensorflow.org/
@@ -31,14 +31,12 @@ training/inference model with one of the backends that are now enabled.
 For this early |release| release, we're providing :doc:`framework-integration-guides`,
 for:

-* :doc:`framework-integration-guides` framework,
-* :doc:`framework-integration-guides` framework, and
+* :doc:`MXNet<framework-integration-guides>` framework,
+* :doc:`Tensorflow<framework-integration-guides>` framework, and
 * neon™ `frontend framework`_.

 Integration guides for other frameworks are tentatively forthcoming.

 .. _GTest framework: https://github.com/google/googletest.git
-.. _MXNet: http://mxnet.incubator.apache.org/
-.. _TensorFlow: https://www.tensorflow.org/
 .. _frontend framework: http://neon.nervanasys.com/index.html/
@@ -36,9 +36,7 @@ set (SRC
     ops/abs.cpp
     ops/add.cpp
     ops/avg_pool.cpp
-    ops/binary_elementwise_arithmetic.cpp
-    ops/binary_elementwise_comparison.cpp
-    ops/binary_elementwise.cpp
+    ops/batch_norm.cpp
     ops/broadcast.cpp
     ops/concatenate.cpp
     ops/constant.cpp
@@ -78,8 +76,12 @@ set (SRC
     ops/sum.cpp
     ops/tan.cpp
     ops/tanh.cpp
-    ops/unary_elementwise_arithmetic.cpp
-    ops/unary_elementwise.cpp
+    ops/util/binary_elementwise_arithmetic.cpp
+    ops/util/binary_elementwise_comparison.cpp
+    ops/util/binary_elementwise.cpp
+    ops/util/requires_tensor_view_args.cpp
+    ops/util/unary_elementwise_arithmetic.cpp
+    ops/util/unary_elementwise.cpp
     pass/dump_sorted.cpp
     pass/graph_rewrite.cpp
     pass/inliner.cpp
@@ -204,6 +206,19 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND
     set_source_files_properties(codegen/compiler.cpp PROPERTIES COMPILE_DEFINITIONS "${HEADER_SEARCH_DEFINES}")
     set(NGRAPH_CPU_DEBUGINFO_ENABLE 0 CACHE STRING "Enable debuginfo in the CPU backend")

+    if(NGRAPH_DISTRIBUTED_ENABLE AND MPI_CXX_INCLUDE_PATH)
+        include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
+        link_directories(${MPI_C_LIBRARIES} ${MPI_CXX_LIBRARIES})
+
+        # Add sources for distributed ngraph
+        # and all its dependencies
+        set(SRC ${SRC}
+            ops/allreduce.cpp
+            )
+
+        set_property(SOURCE codegen/compiler.cpp APPEND PROPERTY COMPILE_DEFINITIONS
+            "MPI_HEADER_PATH=\"${MPI_C_INCLUDE_PATH}\";")
+    endif()
+
     # GPU backend current requires CPU because they share compiler.cpp,
     # and compiler.cpp requires MKLDNN
     if(NGRAPH_GPU_ENABLE)
@@ -238,7 +253,7 @@ if (NGRAPH_CPU_ENABLE AND LLVM_INCLUDE_DIR AND MKLDNN_INCLUDE_DIR)
     # Generate the resource file containing all headers used by the codegen compiler
     add_custom_target(header_resource
         resource_generator --output ${CMAKE_BINARY_DIR}/header_resource.hpp --base codegen
-        DEPENDS resource_generator eigen ext_llvm ext_mkldnn
+        DEPENDS resource_generator ext_eigen ext_llvm ext_mkldnn
         BYPRODUCTS
     )
     add_dependencies(ngraph header_resource)
@@ -293,6 +308,10 @@ if(NGRAPH_GPU_ENABLE AND CUDA_LIBRARIES)
     target_link_libraries(ngraph PRIVATE ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDNN_LIBRARIES})
 endif()

+if(MPI_CXX_INCLUDE_PATH)
+    target_link_libraries(ngraph PRIVATE ${MPI_CXX_LIBRARIES})
+endif()
+
 # Argon
 if (NGRAPH_ARGON_ENABLE)
     target_link_libraries(ngraph PRIVATE ${ARGON_TRANSFORMER_LIB_DIR}/libargon.so)
@@ -313,35 +332,33 @@ install(DIRECTORY
     FILES_MATCHING PATTERN "*.hpp"
 )

-if (NOT APPLE)
-    install(DIRECTORY
-        ${MKLDNN_LIB_DIR}/
-        DESTINATION "${NGRAPH_INSTALL_LIB}"
-    )
+install(DIRECTORY
+    ${MKLDNN_LIB_DIR}/
+    DESTINATION "${NGRAPH_INSTALL_LIB}"
+)

-    if (NGRAPH_TBB_ENABLE)
-        install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tbb_build/tbb_release/
-            DESTINATION ${NGRAPH_INSTALL_LIB}
-            FILES_MATCHING PATTERN "libtbb.so.*"
-        )
-        install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tbb_build/tbb_debug/
-            DESTINATION ${NGRAPH_INSTALL_LIB}
-            FILES_MATCHING PATTERN "libtbb_debug.so.*"
-        )
-    endif()
+if (NGRAPH_TBB_ENABLE)
+    install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tbb_build/tbb_release/
+        DESTINATION ${NGRAPH_INSTALL_LIB}
+        FILES_MATCHING PATTERN "libtbb.so.*"
+    )
+    install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tbb_build/tbb_debug/
+        DESTINATION ${NGRAPH_INSTALL_LIB}
+        FILES_MATCHING PATTERN "libtbb_debug.so.*"
+    )
+endif()

-    if (NGRAPH_ARGON_ENABLE)
-        install(DIRECTORY ${ARGON_TRANSFORMER_LIB_DIR}/
-            DESTINATION ${NGRAPH_INSTALL_LIB}
-            FILES_MATCHING PATTERN "*.so"
-        )
-        install(DIRECTORY ${ARGON_TRANSFORMER_INCLUDE_DIR}/
-            DESTINATION ${NGRAPH_INSTALL_INCLUDE}
-            FILES_MATCHING PATTERN "*.hpp"
-        )
-        install(DIRECTORY ${ARGON_TRANSFORMER_INCLUDE_DIR}/
-            DESTINATION ${NGRAPH_INSTALL_INCLUDE}
-            FILES_MATCHING PATTERN "*.h"
-        )
-    endif()
-endif()
+if (NGRAPH_ARGON_ENABLE)
+    install(DIRECTORY ${ARGON_TRANSFORMER_LIB_DIR}/
+        DESTINATION ${NGRAPH_INSTALL_LIB}
+        FILES_MATCHING PATTERN "*.so"
+    )
+    install(DIRECTORY ${ARGON_TRANSFORMER_INCLUDE_DIR}/
+        DESTINATION ${NGRAPH_INSTALL_INCLUDE}
+        FILES_MATCHING PATTERN "*.hpp"
+    )
+    install(DIRECTORY ${ARGON_TRANSFORMER_INCLUDE_DIR}/
+        DESTINATION ${NGRAPH_INSTALL_INCLUDE}
+        FILES_MATCHING PATTERN "*.h"
+    )
+endif()
 endif()
@@ -151,6 +151,8 @@ void codegen::StaticCompiler::initialize()
     // Prepare DiagnosticEngine
     IntrusiveRefCntPtr<DiagnosticOptions> diag_options = new DiagnosticOptions();
     diag_options->ErrorLimit = 20;
+    diag_options->ShowCarets = false;
+    diag_options->ShowFixits = false;
     IntrusiveRefCntPtr<DiagnosticIDs> diag_id(new DiagnosticIDs());
     DiagnosticsEngine diag_engine(diag_id, &*diag_options);
@@ -206,19 +208,8 @@ void codegen::StaticCompiler::initialize()
     }

     // Enable various target features
-    // Most of these are for Eigen
     auto& TO = m_compiler->getInvocation().getTargetOpts();
     TO.CPU = sys::getHostCPUName();
-    TO.FeaturesAsWritten.emplace_back("+sse");
-    TO.FeaturesAsWritten.emplace_back("+sse2");
-    TO.FeaturesAsWritten.emplace_back("+sse3");
-    TO.FeaturesAsWritten.emplace_back("+ssse3");
-    TO.FeaturesAsWritten.emplace_back("+sse4.1");
-    TO.FeaturesAsWritten.emplace_back("+sse4.2");
-    TO.FeaturesAsWritten.emplace_back("+avx");
-    TO.FeaturesAsWritten.emplace_back("+avx2");
-    TO.FeaturesAsWritten.emplace_back("+fma");
 }

 codegen::StaticCompiler::~StaticCompiler()
@@ -351,6 +342,15 @@ void codegen::StaticCompiler::configure_search_path()
 {
 #ifdef USE_BUILTIN
     load_headers_from_resource();
+#elif defined(__APPLE__)
+    add_header_search_path(EIGEN_HEADERS_PATH);
+    add_header_search_path(MKLDNN_HEADERS_PATH);
+    add_header_search_path(TBB_HEADERS_PATH);
+    add_header_search_path(NGRAPH_HEADERS_PATH);
+    add_header_search_path(INSTALLED_HEADERS_PATH);
+    add_header_search_path(CLANG_BUILTIN_HEADERS_PATH);
+    add_header_search_path("/Library/Developer/CommandLineTools/usr/include/c++/v1");
 #else
     // Add base toolchain-supplied header paths
     // Ideally one would use the Linux toolchain definition in clang/lib/Driver/ToolChains.h
@@ -400,6 +400,10 @@ void codegen::StaticCompiler::configure_search_path()
     // Only needed for GPU backend
     add_header_search_path(CUDA_HEADER_PATHS);
 #endif
+
+#ifdef NGRAPH_DISTRIBUTED
+    add_header_search_path(MPI_HEADER_PATH);
+#endif
 }

 void codegen::StaticCompiler::load_headers_from_resource()
@@ -78,6 +78,13 @@ void codegen::ExecutionEngine::finalize()
 void* codegen::ExecutionEngine::get_pointer_to_named_function(const std::string& func_name)
 {
+// For whatever reason, macOS seems to expect that we prefix this with an underscore.
+#ifdef __APPLE__
+    std::string fname = "_" + func_name;
+#else
+    const std::string& fname = func_name;
+#endif
+
     // set AbortOnFailure flag to false so call fails by returning nullptr
-    return m_execution_engine->getPointerToNamedFunction(func_name, false);
+    return m_execution_engine->getPointerToNamedFunction(fname, false);
 }
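For reference, a hedged sketch of a call site; the symbol name and function type below are illustrative, not from this change:

    // "my_kernel" is a hypothetical symbol compiled into the module earlier.
    // Callers pass the plain C name; the underscore prefix above is applied
    // internally on macOS.
    auto fp = reinterpret_cast<void (*)(float*, size_t)>(
        engine.get_pointer_to_named_function("my_kernel"));
    if (fp != nullptr) // nullptr on failure, since AbortOnFailure is false
    {
        fp(buffer, count);
    }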
@@ -26,45 +26,45 @@
 using namespace std;

-namespace nervana
+namespace ngraph
 {
     class thread_starter;
 }

-string nervana::logger::log_path;
-deque<string> nervana::logger::queue;
+string ngraph::logger::log_path;
+deque<string> ngraph::logger::queue;

 static mutex queue_mutex;
 static condition_variable queue_condition;
 static unique_ptr<thread> queue_thread;
 static bool active = false;

-std::ostream& nervana::get_nil_stream()
+std::ostream& ngraph::get_nil_stream()
 {
     static std::stringstream nil;
     return nil;
 }

-class nervana::thread_starter
+class ngraph::thread_starter
 {
 public:
-    thread_starter() { nervana::logger::start(); }
-    virtual ~thread_starter() { nervana::logger::stop(); }
+    thread_starter() { ngraph::logger::start(); }
+    virtual ~thread_starter() { ngraph::logger::stop(); }
 };

-static nervana::thread_starter _starter;
+static ngraph::thread_starter _starter;

-void nervana::logger::set_log_path(const string& path)
+void ngraph::logger::set_log_path(const string& path)
 {
     log_path = path;
 }

-void nervana::logger::start()
+void ngraph::logger::start()
 {
     active = true;
     queue_thread = unique_ptr<thread>(new thread(&thread_entry, nullptr));
 }

-void nervana::logger::stop()
+void ngraph::logger::stop()
 {
     {
         unique_lock<std::mutex> lk(queue_mutex);
@@ -74,12 +74,12 @@ void nervana::logger::stop()
     queue_thread->join();
 }

-void nervana::logger::process_event(const string& s)
+void ngraph::logger::process_event(const string& s)
 {
     cout << s << "\n";
 }

-void nervana::logger::thread_entry(void* param)
+void ngraph::logger::thread_entry(void* param)
 {
     unique_lock<std::mutex> lk(queue_mutex);
     while (active)
@@ -93,14 +93,14 @@ void nervana::logger::thread_entry(void* param)
     }
 }

-void nervana::logger::log_item(const string& s)
+void ngraph::logger::log_item(const string& s)
 {
     unique_lock<std::mutex> lk(queue_mutex);
     queue.push_back(s);
     queue_condition.notify_one();
 }

-nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const char* func)
+ngraph::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const char* func)
 {
     switch (type)
     {
@@ -124,7 +124,7 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const
     _stream << "\t";
 }

-nervana::log_helper::~log_helper()
+ngraph::log_helper::~log_helper()
 {
     cout << _stream.str() << endl;
     // logger::log_item(_stream.str());
@@ -20,7 +20,7 @@
 #include <sstream>
 #include <stdexcept>

-namespace nervana
+namespace ngraph
 {
     class conststring
     {
@@ -93,30 +93,30 @@ namespace nervana
     extern std::ostream& get_nil_stream();

 #define NGRAPH_ERR                                                  \
-    nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_ERROR,         \
-                        nervana::get_file_name(__FILE__),           \
+    ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_ERROR,           \
+                       ngraph::get_file_name(__FILE__),             \
                        __LINE__,                                    \
                        __PRETTY_FUNCTION__)                         \
        .stream()

 #define NGRAPH_WARN                                                 \
-    nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_WARNING,       \
-                        nervana::get_file_name(__FILE__),           \
+    ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_WARNING,         \
+                       ngraph::get_file_name(__FILE__),             \
                        __LINE__,                                    \
                        __PRETTY_FUNCTION__)                         \
        .stream()

 #define NGRAPH_INFO                                                 \
-    nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_INFO,          \
-                        nervana::get_file_name(__FILE__),           \
+    ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_INFO,            \
+                       ngraph::get_file_name(__FILE__),             \
                        __LINE__,                                    \
                        __PRETTY_FUNCTION__)                         \
        .stream()

 // #define NGRAPH_DEBUG \
-//     nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \
-//                         nervana::get_file_name(__FILE__), \
+//     ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_DEBUG, \
+//                        ngraph::get_file_name(__FILE__), \
 //                        __LINE__, \
 //                        __PRETTY_FUNCTION__) \
 //         .stream()
-#define NGRAPH_DEBUG nervana::get_nil_stream()
+#define NGRAPH_DEBUG ngraph::get_nil_stream()
 }
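Usage of the macros is unchanged apart from the namespace move; a minimal sketch:

    #include "ngraph/log.hpp"

    void example()
    {
        // Each macro builds a message through ngraph::log_helper and prints
        // it when the temporary is destroyed at the end of the statement.
        NGRAPH_INFO << "compiled function with " << 3 << " outputs";
        NGRAPH_WARN << "falling back to reference kernel";
        NGRAPH_DEBUG << "discarded: routed to the nil stream";
    }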
@@ -127,3 +127,7 @@
 #include "ngraph/shape.hpp"
 #include "ngraph/types/element_type.hpp"
 #include "ngraph/types/type.hpp"
+
+#ifdef NGRAPH_DISTRIBUTED
+#include "ngraph/ops/allreduce.hpp"
+#endif
@@ -144,42 +144,6 @@ void Node::set_name(const string& name)
     }
 }

-void Node::assert_argument_list_equivalency(const Nodes& b)
-{
-    bool arguments_equal = true;
-    if (this->m_arguments.size() == b.size())
-    {
-        for (size_t i = 0; i < this->m_arguments.size(); i++)
-        {
-            arguments_equal = arguments_equal && this->m_arguments.at(i) == b.at(i);
-        }
-    }
-    else
-    {
-        arguments_equal = false;
-    }
-
-    if (!arguments_equal)
-    {
-        std::cout << "node = " << this->get_name() << std::endl;
-        std::cout << "m_arguments" << std::endl;
-        for (auto arg : this->m_arguments)
-        {
-            std::cout << "arg = " << arg->get_name() << std::endl;
-        }
-        std::cout << "results" << std::endl;
-        for (auto arg : b)
-        {
-            std::cout << "arg = " << arg->get_name() << std::endl;
-        }
-    }
-
-    if (!arguments_equal)
-    {
-        throw "Arguments aren't equal";
-    }
-}
-
 std::shared_ptr<Node> Node::get_input_op(size_t index)
 {
     for (auto arg : m_arguments)
@@ -201,7 +165,10 @@ Nodes Node::get_input_ops() //const
             result.push_back(i.get_output().get_node());
         }
     }
-    assert_argument_list_equivalency(result);
+    if (m_arguments != result)
+    {
+        throw ngraph_error("Arguments aren't equal: different values");
+    }
     return result;
 }
@@ -170,8 +170,6 @@ namespace ngraph
     protected:
         void add_output(const element::Type& element_type, const Shape& shape);

-        void assert_argument_list_equivalency(const Nodes& b);
-        bool test_identical(const Node&) const;
         std::string m_node_type;
         std::multiset<Node*> m_users;
@@ -18,7 +18,7 @@
 #include <memory>

-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"

 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
     {
         /// \brief Elementwise absolute value operation.
         ///
-        class Abs : public UnaryElementwiseArithmetic
+        class Abs : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs an absolute value operation.
@@ -45,7 +45,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Abs>(new_args.at(0));
             }
@@ -18,7 +18,7 @@
 #include <memory>

-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"

 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
     {
         /// \brief Elementwise inverse cosine (arccos) operation.
         ///
-        class Acos : public UnaryElementwiseArithmetic
+        class Acos : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs an arccos operation.
@@ -45,7 +45,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Acos>(new_args.at(0));
             }
         };
@@ -18,7 +18,7 @@
 #include <memory>

-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"

 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
     {
         /// \brief Elementwise addition operation.
         ///
-        class Add : public BinaryElementwiseArithmetic
+        class Add : public util::BinaryElementwiseArithmetic
        {
         public:
             /// \brief Constructs an addition operation.
@@ -47,7 +47,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Add>(new_args.at(0), new_args.at(1));
             }
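The same brace-and-throw guard is applied to every `copy_with_new_args` override in this merge. A hedged sketch of how the method is typically used, e.g. from a graph-rewrite pass (`x`, `y`, and `original_add` are hypothetical nodes):

    // Clone an Add node onto new inputs; arity must match or
    // ngraph_error("Incorrect number of new arguments") is thrown.
    std::vector<std::shared_ptr<ngraph::Node>> new_args{x, y};
    auto copy = original_add->copy_with_new_args(new_args);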
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#ifdef NGRAPH_DISTRIBUTED
#include "ngraph/ops/allreduce.hpp"
using namespace std;
using namespace ngraph;
op::AllReduce::AllReduce(const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs("AllReduce", {arg})
{
auto& input = m_inputs.at(0);
set_value_type_checked(
make_shared<TensorViewType>(input.get_element_type(), input.get_shape()));
if ((arg->get_element_type() != element::f32) && (arg->get_element_type() != element::f64))
{
throw ngraph_error("Unsupported data type for AllReduce");
}
}
#endif
// ----------------------------------------------------------------------------
// Copyright 2017 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------
#pragma once
#ifdef NGRAPH_DISTRIBUTED
#include <memory>
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
class AllReduce : public util::RequiresTensorViewArgs
{
public:
AllReduce(const std::shared_ptr<Node>& arg);
virtual std::shared_ptr<Node> copy_with_new_args(
const std::vector<std::shared_ptr<Node>>& new_args) const override
{
if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<AllReduce>(new_args.at(0));
}
};
}
}
#endif
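A hedged sketch of placing the new op in a graph (the Parameter constructor form is assumed from nGraph core; only f32/f64 element types are accepted, per the constructor check above):

    #ifdef NGRAPH_DISTRIBUTED
    #include "ngraph/ops/allreduce.hpp"
    #include "ngraph/ops/parameter.hpp"

    // Sum a parameter's value elementwise across all participating processes.
    auto arg = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                       ngraph::Shape{2, 3});
    auto sum = std::make_shared<ngraph::op::AllReduce>(arg);
    #endif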
@@ -18,7 +18,7 @@
 #include <memory>

-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"

 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
     {
         /// \brief Elementwise inverse sine (arcsin) operation.
         ///
-        class Asin : public UnaryElementwiseArithmetic
+        class Asin : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs an arcsin operation.
@@ -45,7 +45,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Asin>(new_args.at(0));
             }
         };
@@ -18,7 +18,7 @@
 #include <memory>

-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"

 namespace ngraph
 {
@@ -26,7 +26,7 @@ namespace ngraph
     {
         /// \brief Elementwise inverse tangent (arctan) operation.
         ///
-        class Atan : public UnaryElementwiseArithmetic
+        class Atan : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs an arctan operation.
@@ -45,7 +45,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Atan>(new_args.at(0));
             }
         };
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -24,7 +24,7 @@ namespace ngraph
     {
         /// \brief Batched average pooling operation, with optional padding and window stride.
         ///
-        class AvgPool : public RequiresTensorViewArgs
+        class AvgPool : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a batched average pooling operation.
@@ -69,7 +69,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<AvgPool>(new_args.at(0),
                                                  m_window_shape,
@@ -96,7 +98,7 @@ namespace ngraph
             Shape m_padding_above;
         };
-        class AvgPoolBackprop : public RequiresTensorViewArgs
+        class AvgPoolBackprop : public util::RequiresTensorViewArgs
         {
         public:
             AvgPoolBackprop(const Shape& forward_arg_shape,
@@ -110,7 +112,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 AvgPoolBackprop* avpn = new AvgPoolBackprop(m_forward_arg_shape,
                                                             new_args.at(0),
...
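A hedged sketch of constructing the op above: 2x2 average pooling with stride 2 over an NCHW tensor. The (arg, window_shape, window_movement_strides) overload and the ngraph::Strides type are assumptions inferred from the members used in copy_with_new_args; they are not spelled out in this hunk.

auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32,
                                                    ngraph::Shape{1, 3, 8, 8});
auto pool = std::make_shared<ngraph::op::AvgPool>(
    data, ngraph::Shape{2, 2}, ngraph::Strides{2, 2}); // output H x W becomes 4 x 4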
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/ops/batch_norm.hpp"
#include "ngraph/ops/constant.hpp"
ngraph::op::BatchNorm::BatchNorm(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: RequiresTensorViewArgs("BatchNorm", {gamma, beta, input, mean, variance})
, m_bn_input_shape(input->get_shape())
, m_bn_variance_shape(variance->get_shape())
, m_bn_mean_shape(mean->get_shape())
, m_epsilon(eps)
{
add_output(input->get_element_type(), m_bn_input_shape);
if (m_bn_input_shape.size() < 2)
{
throw ngraph_error("input tensor to batchnorm much have tensor of atleast rank 2");
}
if (m_bn_input_shape[1] == 0)
{
throw ngraph_error(
"input tensor must have atleast one channel axis for batch normalization");
}
if ((m_bn_mean_shape.size() != 1) && (m_bn_variance_shape.size() != 1) &&
(gamma->get_shape().size() != 1) && (beta->get_shape().size() != 1))
{
throw ngraph_error("gamma, beta, mean, variance shoud have all rank 1");
}
// assuming input shape (N, C, H, W), check if the size of mean and
// variance are equal to channel axis
if (mean->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("mean size is not equal to input channel size");
}
if (variance->get_shape()[0] != m_bn_input_shape[1])
{
throw ngraph_error("variance size is not equal to input channel size");
}
if (variance->get_shape().size() != mean->get_shape().size())
{
throw ngraph_error("mean and variance rank does not match");
}
if (gamma->get_shape().size() != beta->get_shape().size())
{
throw ngraph_error("gamma and beta rank does not match");
}
if (input->get_element_type() != mean->get_element_type())
{
throw ngraph_error("input tensor and mean element type does not match");
}
if (input->get_element_type() != variance->get_element_type())
{
throw ngraph_error("input tensor and variance element type does not match");
}
if (gamma->get_element_type() != beta->get_element_type())
{
throw ngraph_error("gamma and beta element type does not match");
}
}
std::shared_ptr<ngraph::Node> ngraph::op::BatchNorm::copy_with_new_args(
    const std::vector<std::shared_ptr<ngraph::Node>>& new_args) const
{
    if (new_args.size() != 5)
    {
        throw ngraph_error("Incorrect number of new arguments");
    }
    return std::make_shared<BatchNorm>(
        m_epsilon, new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4));
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once

#include <memory>

#include "ngraph/node.hpp"
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp"

namespace ngraph
{
    namespace op
    {
        class BatchNorm : public util::RequiresTensorViewArgs
        {
        public:
            BatchNorm(double eps,
                      std::shared_ptr<Node> gamma,
                      std::shared_ptr<Node> beta,
                      std::shared_ptr<Node> input,
                      std::shared_ptr<Node> mean,
                      std::shared_ptr<Node> variance);

            const Shape& get_inputs_shape() const { return m_bn_input_shape; }
            const Shape& get_variance_shape() const { return m_bn_variance_shape; }
            const Shape& get_mean_shape() const { return m_bn_mean_shape; }
            double get_eps_value() const { return m_epsilon; }
            virtual std::shared_ptr<Node> copy_with_new_args(
                const std::vector<std::shared_ptr<Node>>& new_args) const override;

        private:
            Shape m_bn_input_shape;
            Shape m_bn_variance_shape;
            Shape m_bn_mean_shape;
            double m_epsilon;
        };
    }
}
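A hedged sketch of wiring up the BatchNorm node declared above, assuming `using namespace ngraph;` and the Parameter header. For an input of shape (N, C, H, W), gamma, beta, mean, and variance must all be rank-1 tensors of length C, mirroring the constructor checks in batch_norm.cpp.

auto input    = std::make_shared<op::Parameter>(element::f32, Shape{4, 3, 2, 2});
auto gamma    = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto beta     = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto mean     = std::make_shared<op::Parameter>(element::f32, Shape{3});
auto variance = std::make_shared<op::Parameter>(element::f32, Shape{3});
// Argument order follows the constructor: (eps, gamma, beta, input, mean, variance).
auto bn = std::make_shared<op::BatchNorm>(1e-5, gamma, beta, input, mean, variance);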
@@ -16,14 +16,14 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
     namespace op
     {
         /// \brief Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
-        class Broadcast : public RequiresTensorViewArgs
+        class Broadcast : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a broadcast operation.
@@ -40,7 +40,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Broadcast>(new_args.at(0), m_shape, m_broadcast_axes);
             }
...
@@ -16,14 +16,14 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
     namespace op
     {
         /// \brief Elementwise ceiling operation.
-        class Ceiling : public UnaryElementwiseArithmetic
+        class Ceiling : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a ceiling operation.
@@ -38,7 +38,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Ceiling>(new_args.at(0));
             }
         };
...
@@ -18,14 +18,14 @@
 #include <memory>
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
     namespace op
     {
         /// \brief Concatenation operation.
-        class Concat : public RequiresTensorViewArgs
+        class Concat : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a concatenation operation.
...
@@ -29,20 +29,6 @@ namespace ngraph
     namespace op
     {
         /// \brief Class for constants.
-        ///
-        /// ## Parameters
-        ///
-        /// |          | Description |
-        /// | -------- | ----------- |
-        /// | `type`   | The ngraph::element::Type of the tensor constant. |
-        /// | `shape`  | The ngraph::Shape of the tensor constant. |
-        /// | `values` | A list of values to initialize the underlying tensor constant. |
-        ///
-        /// ## Output
-        ///
-        /// | Type | Description |
-        /// | ---- | ----------- |
-        /// | \f$E[d_1,\dots,d_n]\f$ | A constant tensor with the specified element type, shape, and values. |
         class Constant : public Node
         {
         public:
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise.hpp"
 #include "ngraph/types/type.hpp"
 namespace ngraph
@@ -24,28 +24,7 @@ namespace ngraph
     namespace op
     {
         /// \brief Elementwise type conversion operation.
-        ///
-        /// Each scalar in the input tensor is converted to the specified output element type. Note that the conversion may
-        /// result in loss of precision. For example, conversion from `float32` to `int32` is allowed.
-        ///
-        /// ## Parameters
-        ///
-        /// |                | Description |
-        /// | -------------- | ----------- |
-        /// | `element_type` | The element type \f$E'\f$ to convert to. |
-        ///
-        /// ## Inputs
-        ///
-        /// |       | Type | Description |
-        /// | ----- | ---- | ----------- |
-        /// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
-        ///
-        /// ## Output
-        ///
-        /// | Type | Description |
-        /// | ---- | ----------- |
-        /// | \f$E'[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{convert}_{(E,E')}(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Convert : public UnaryElementwise
+        class Convert : public util::UnaryElementwise
         {
         public:
             /// \brief Constructs a conversion operation.
@@ -58,7 +37,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Convert>(new_args.at(0), m_element_type);
             }
...
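A hedged sketch of the op above: elementwise f32 to i32 conversion, matching the (arg, element_type) argument order visible in copy_with_new_args. Precision can be lost, as the removed doc table noted. Assumes `using namespace ngraph;`.

auto x        = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto x_as_i32 = std::make_shared<op::Convert>(x, element::i32);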
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -24,7 +24,7 @@ namespace ngraph
     {
         /// \brief Batched convolution operation, with optional window dilation and stride.
         ///
-        class Convolution : public RequiresTensorViewArgs
+        class Convolution : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a batched convolution operation.
@@ -151,7 +151,7 @@ namespace ngraph
         };
         /// \brief Data batch backprop for batched convolution operation.
-        class ConvolutionBackpropData : public RequiresTensorViewArgs
+        class ConvolutionBackpropData : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a batched-convolution data batch-backprop operation.
@@ -246,7 +246,7 @@ namespace ngraph
         };
         /// \brief Filters backprop for batched convolution operation.
-        class ConvolutionBackpropFilters : public RequiresTensorViewArgs
+        class ConvolutionBackpropFilters : public util::RequiresTensorViewArgs
        {
         public:
             /// \brief Constructs a batched-convolution filter-backprop operation.
...
@@ -16,26 +16,14 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
     namespace op
     {
         /// \brief Elementwise cosine operation.
-        ///
-        /// ## Inputs
-        ///
-        /// |       | Type | Description |
-        /// | ----- | ---- | ----------- |
-        /// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
-        ///
-        /// ## Output
-        ///
-        /// | Type | Description |
-        /// | ---- | ----------- |
-        /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \cos(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Cos : public UnaryElementwiseArithmetic
+        class Cos : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a cosine operation.
@@ -50,7 +38,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Cos>(new_args.at(0));
             }
...
@@ -16,26 +16,14 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
     namespace op
     {
         /// \brief Elementwise hyperbolic cosine (cosh) operation.
-        ///
-        /// ## Inputs
-        ///
-        /// |       | Type | Description |
-        /// | ----- | ---- | ----------- |
-        /// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
-        ///
-        /// ## Output
-        ///
-        /// | Type | Description |
-        /// | ---- | ----------- |
-        /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \cosh(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Cosh : public UnaryElementwiseArithmetic
+        class Cosh : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a hyperbolic cosine operation.
@@ -50,7 +38,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Cosh>(new_args.at(0));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mathbin{/} \texttt{arg1}[i_1,\dots,i_n]\f$ |
-        class Divide : public BinaryElementwiseArithmetic
+        class Divide : public util::BinaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a division operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Divide>(new_args.at(0), new_args.at(1));
             }
...
@@ -18,7 +18,7 @@
 #include <utility>
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -54,7 +54,7 @@ namespace ngraph
         /// | ---- | ----------- |
         /// | \f$E[d_1,\dots,d_n,d''_1,\dots,d''_p]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \Sigma_{0 \le j_1 < d'_1, \dots, 0 \le j_m < d'_m}(\mathtt{arg0}[i_1,\dots,i_n,j_1,\dots,j_m] \cdot \mathtt{arg1}[j_1,\dots,j_m,k_1,\dots,k_p])\f$ or, if \f$m = 0\f$, \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \mathtt{arg0}[i_1,\dots,i_n] \cdot \mathtt{arg1}[k_1,\dots,k_p]\f$. |
         ///
-        class Dot : public RequiresTensorViewArgs
+        class Dot : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a dot product operation.
@@ -83,7 +83,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Dot>(
                     new_args.at(0), new_args.at(1), m_reduction_axes_count);
             }
...
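A hedged sketch of the op above: an ordinary matrix product expressed as a Dot with one reduction axis, per the (arg0, arg1, reduction_axes_count) constructor used in copy_with_new_args. Assumes `using namespace ngraph;`.

auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto b = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto c = std::make_shared<op::Dot>(a, b, 1); // contracts the 3-axis; result shape {2, 4}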
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class Equal : public BinaryElementwiseComparison
+        class Equal : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs an is-equal operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Equal>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \exp(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Exp : public UnaryElementwiseArithmetic
+        class Exp : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs an exponential operation.
@@ -50,7 +50,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Exp>(new_args.at(0));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \lfloor \texttt{arg}[i_1,\dots,i_n] \rfloor\f$ |
-        class Floor : public UnaryElementwiseArithmetic
+        class Floor : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a floor operation.
@@ -50,7 +50,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Floor>(new_args.at(0));
             }
         };
...
@@ -54,7 +54,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<GetOutputElement>(new_args.at(0), m_n);
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \gt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class Greater : public BinaryElementwiseComparison
+        class Greater : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs a greater-than operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Greater>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \geq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class GreaterEq : public BinaryElementwiseComparison
+        class GreaterEq : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs a greater-than-or-equal operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<GreaterEq>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \lt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class Less : public BinaryElementwiseComparison
+        class Less : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs a less-than operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Less>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \leq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class LessEq : public BinaryElementwiseComparison
+        class LessEq : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs a less-than-or-equal operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<LessEq>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \ln(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Log : public UnaryElementwiseArithmetic
+        class Log : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a natural log operation.
@@ -50,7 +50,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Log>(new_args.at(0));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -41,7 +41,7 @@ namespace ngraph
         /// T_\textit{out}[a,c,i_1,\dots,i_n] = \max_{j_1 = s_1 i_1, \dots, j_n = s_n i_n}^{j_1 = s_1 i_1 + w_1 - 1, \dots, j_n = s_n i_n + w_n - 1} (T_\textit{in}[a,c,j_1,\dots,j_n])
         /// \f]
         ///
-        class MaxPool : public RequiresTensorViewArgs
+        class MaxPool : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a batched max pooling operation.
@@ -76,7 +76,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<MaxPool>(
                     new_args.at(0), m_window_shape, m_window_movement_strides);
             }
@@ -99,7 +101,7 @@ namespace ngraph
             Shape m_padding_above;
         };
-        class MaxPoolBackprop : public RequiresTensorViewArgs
+        class MaxPoolBackprop : public util::RequiresTensorViewArgs
         {
         public:
             MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
@@ -114,7 +116,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 MaxPoolBackprop* mpbp = new MaxPoolBackprop(new_args.at(0),
                                                             new_args.at(1),
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \max(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ |
-        class Maximum : public BinaryElementwiseArithmetic
+        class Maximum : public util::BinaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a maximum operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Maximum>(new_args.at(0), new_args.at(1));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \min(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ |
-        class Minimum : public BinaryElementwiseArithmetic
+        class Minimum : public util::BinaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a minimum operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Minimum>(new_args.at(0), new_args.at(1));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \cdot \texttt{arg1}[i_1,\dots,i_n]\f$ |
-        class Multiply : public BinaryElementwiseArithmetic
+        class Multiply : public util::BinaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a multiplication operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Multiply>(new_args.at(0), new_args.at(1));
             }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = -(\texttt{arg}[i_1,\dots,i_n])\f$ |
-        class Negative : public UnaryElementwiseArithmetic
+        class Negative : public util::UnaryElementwiseArithmetic
         {
         public:
             /// \brief Constructs a negation operation.
@@ -50,7 +50,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Negative>(new_args.at(0));
             }
...
@@ -21,6 +21,6 @@ using namespace ngraph;
 using namespace std;
 op::Not::Not(const shared_ptr<Node>& arg)
-    : op::UnaryElementwise("Not", arg->get_element_type(), arg)
+    : UnaryElementwise("Not", arg->get_element_type(), arg)
 {
 }
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/unary_elementwise.hpp"
 namespace ngraph
 {
@@ -35,7 +35,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg}[i_1,\dots,i_n] = 0\text{, else } 0\f$ |
-        class Not : public UnaryElementwise
+        class Not : public util::UnaryElementwise
         {
         public:
             /// \brief Constructs a logical negation operation.
@@ -47,7 +47,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<Not>(new_args.at(0));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/binary_elementwise_comparison.hpp"
 namespace ngraph
 {
@@ -36,7 +36,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
-        class NotEqual : public BinaryElementwiseComparison
+        class NotEqual : public util::BinaryElementwiseComparison
         {
         public:
             /// \brief Constructs a not-equal operation.
@@ -52,7 +52,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 2)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<NotEqual>(new_args.at(0), new_args.at(1));
             }
         };
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -42,7 +42,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
-        class OneHot : public RequiresTensorViewArgs
+        class OneHot : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a one-hot operation.
@@ -56,7 +56,9 @@ namespace ngraph
                 const std::vector<std::shared_ptr<Node>>& new_args) const override
             {
                 if (new_args.size() != 1)
+                {
                     throw ngraph_error("Incorrect number of new arguments");
+                }
                 return std::make_shared<OneHot>(new_args.at(0), m_shape, m_one_hot_axis);
             }
...
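A hedged sketch of the op above: expanding a rank-1 index tensor into a one-hot rank-2 tensor, using the (arg, shape, one_hot_axis) order from copy_with_new_args. The i32 element type is an assumption; per the docs, the result is undefined for non-integral or out-of-range values. Assumes `using namespace ngraph;`.

auto idx     = std::make_shared<op::Parameter>(element::i32, Shape{4});
auto one_hot = std::make_shared<op::OneHot>(idx, Shape{4, 10}, 1); // 10 classes on axis 1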
@@ -18,23 +18,14 @@
 #include <memory>
 #include <sstream>
-#include "ngraph/except.hpp"
+#include "ngraph/common.hpp"
 #include "ngraph/ops/op.hpp"
 #include "ngraph/types/type.hpp"
 using namespace ngraph;
 using namespace std;
-op::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::string& node_type,
-                                                   const std::vector<std::shared_ptr<Node>>& args)
+op::Op::Op(const std::string& node_type, const Nodes& args)
     : Node(node_type, args)
 {
-    for (auto arg : args)
-    {
-        if (arg->get_output_size() != 1)
-        {
-            throw ngraph_error("Arguments for node type \"" + node_type +
-                               "\" must be tensor views");
-        }
-    }
 }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/ops/op.hpp"
+#include "ngraph/ops/util/requires_tensor_view_args.hpp"
 namespace ngraph
 {
@@ -56,7 +56,7 @@ namespace ngraph
         /// (Note that `below` and `above` here refer respectively to lower- or higher-numbered coordinate indices, and numbering starts at the upper-left corner;
         /// thus inserting a row "below" actually inserts it at the "top" of the matrix.)
         ///
-        class Pad : public RequiresTensorViewArgs
+        class Pad : public util::RequiresTensorViewArgs
         {
         public:
             /// \brief Constructs a generic padding operation.
...
@@ -22,7 +22,7 @@ using namespace std;
 using namespace ngraph;
 op::Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape)
-    : Node("Parameter", {})
+    : Op("Parameter", {})
 {
     add_output(element_type, shape);
 }
@@ -30,7 +30,9 @@ op::Parameter::Parameter(const ngraph::element::Type& element_type, const Shape&
 shared_ptr<Node> op::Parameter::copy_with_new_args(const vector<shared_ptr<Node>>& new_args) const
 {
     if (new_args.size() != 0)
+    {
         throw ngraph_error("Incorrect number of new arguments");
+    }
     const descriptor::Output& output = get_outputs().at(0);
     return make_shared<Parameter>(output.get_element_type(), output.get_shape());
 }
...
@@ -16,7 +16,7 @@
 #pragma once
-#include "ngraph/node.hpp"
+#include "ngraph/ops/op.hpp"
 #include "ngraph/types/type.hpp"
 namespace ngraph
@@ -41,7 +41,7 @@ namespace ngraph
         /// | Type | Description |
         /// | ---- | ----------- |
         /// | \f$T\f$ | The value of the parameter, supplied by the `FunctionCall` to this function or in the initial `ngraph::runtime::CallFrame`. |
-        class Parameter : public Node
+        class Parameter : public op::Op
         {
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
...
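A hedged sketch of the class above: Parameter is how graph inputs are declared, and its copy_with_new_args takes an empty argument list, rebuilding the node from its own output element type and shape per parameter.cpp. Assumes `using namespace ngraph;`.

auto p      = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto p_copy = p->copy_with_new_args({}); // new node, same element type and shape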
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
class Power : public BinaryElementwiseArithmetic class Power : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an exponentiation operation. /// \brief Constructs an exponentiation operation.
...@@ -52,7 +52,9 @@ namespace ngraph ...@@ -52,7 +52,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Power>(new_args.at(0), new_args.at(1)); return std::make_shared<Power>(new_args.at(0), new_args.at(1));
} }
......
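For reference, a hedged sketch of constructing the elementwise Power op above (element::f32 assumed; the two-argument constructor is visible in copy_with_new_args):
    auto base = std::make_shared<op::Parameter>(element::f32, Shape{4});
    auto expn = std::make_shared<op::Parameter>(element::f32, Shape{4});
    auto pw   = std::make_shared<op::Power>(base, expn);  // T[i] = base[i]^expn[i]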
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -82,7 +82,7 @@ namespace ngraph ...@@ -82,7 +82,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. | /// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
class Reduce : public RequiresTensorViewArgs class Reduce : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reduction operation. /// \brief Constructs a reduction operation.
...@@ -100,7 +100,9 @@ namespace ngraph ...@@ -100,7 +100,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reduce>( return std::make_shared<Reduce>(
new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes); new_args.at(0), new_args.at(1), m_reduction_function, m_reduction_axes);
} }
......
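A sketch of building a Reduce node as a sum-reduction over axis 0. The argument order matches copy_with_new_args above; the Function constructor shown is an assumption (its exact form has varied between ngraph revisions):
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto f = std::make_shared<Function>(
        std::make_shared<op::Add>(A, B),
        std::vector<std::shared_ptr<op::Parameter>>{A, B});
    auto arg  = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto init = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto sum0 = std::make_shared<op::Reduce>(arg, init, f, AxisSet{0});  // result Shape{3}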
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -50,7 +50,7 @@ namespace ngraph ...@@ -50,7 +50,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{reduce}(\mathit{reduction\_function},\mathit{arg\_init},V)\f$ where \f$V\f$ is the set of values in the input tensor within the window defined by the lower bound \f$(s_1i_1,\dots,s_ni_n)\f$ and the noninclusive upper bound \f$(s_1i_1 + w_1,\dots,s_ni_n + w_n)\f$. | /// | \f$E[d'_1,\dots,d'_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{reduce}(\mathit{reduction\_function},\mathit{arg\_init},V)\f$ where \f$V\f$ is the set of values in the input tensor within the window defined by the lower bound \f$(s_1i_1,\dots,s_ni_n)\f$ and the noninclusive upper bound \f$(s_1i_1 + w_1,\dots,s_ni_n + w_n)\f$. |
class ReduceWindow : public RequiresTensorViewArgs class ReduceWindow : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reduce-window operation. /// \brief Constructs a reduce-window operation.
...@@ -70,7 +70,9 @@ namespace ngraph ...@@ -70,7 +70,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ReduceWindow>(new_args.at(0), return std::make_shared<ReduceWindow>(new_args.at(0),
new_args.at(1), new_args.at(1),
m_reduction_function, m_reduction_function,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ----------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mod \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mod \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Remainder : public BinaryElementwiseArithmetic class Remainder : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a remainder operation. /// \brief Constructs a remainder operation.
...@@ -54,7 +54,9 @@ namespace ngraph ...@@ -54,7 +54,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Remainder>(new_args.at(0), new_args.at(1)); return std::make_shared<Remainder>(new_args.at(0), new_args.at(1));
} }
}; };
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$ where \f$T[i_1,\dots,i_n] = \texttt{arg1}[j_1,\dots,j_n]\f$ if \f$j_1,\dots,j_n\f$ is in bounds for `arg1` and for all \f$m\f$, \f$i_m = l_m + j_m s_m\f$, otherwise \f$\texttt{arg0}[i_1,\dots,i_n]\f$. | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$ where \f$T[i_1,\dots,i_n] = \texttt{arg1}[j_1,\dots,j_n]\f$ if \f$j_1,\dots,j_n\f$ is in bounds for `arg1` and for all \f$m\f$, \f$i_m = l_m + j_m s_m\f$, otherwise \f$\texttt{arg0}[i_1,\dots,i_n]\f$. |
class ReplaceSlice : public RequiresTensorViewArgs class ReplaceSlice : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a tensor slice replacement operation. /// \brief Constructs a tensor slice replacement operation.
...@@ -78,7 +78,9 @@ namespace ngraph ...@@ -78,7 +78,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<ReplaceSlice>( return std::make_shared<ReplaceSlice>(
new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides); new_args.at(0), new_args.at(1), m_lower_bounds, m_upper_bounds, m_strides);
} }
......
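Sketch of ReplaceSlice writing a 2x2 block into the top-left corner of a 4x4 tensor. The Coordinate and Strides helper types and element::f32 are assumptions; the five-argument constructor order follows copy_with_new_args above:
    auto big   = std::make_shared<op::Parameter>(element::f32, Shape{4, 4});
    auto block = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
    auto rs = std::make_shared<op::ReplaceSlice>(
        big, block,
        Coordinate{0, 0},   // lower bounds (inclusive)
        Coordinate{2, 2},   // upper bounds (exclusive)
        Strides{1, 1});     // step within the slice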
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------ | /// | ------------------------ | ------------------------------------------------------------------------------------------------------ |
/// | \f$E[d'_1,\dots,d'_m]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with its elements rearranged as described above. | /// | \f$E[d'_1,\dots,d'_m]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with its elements rearranged as described above. |
class Reshape : public RequiresTensorViewArgs class Reshape : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reshape operation. /// \brief Constructs a reshape operation.
...@@ -73,7 +73,9 @@ namespace ngraph ...@@ -73,7 +73,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape); return std::make_shared<Reshape>(new_args.at(0), m_input_order, m_output_shape);
} }
......
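Reshape doubles as transpose: the input-order argument permutes axes before the output shape is applied. A sketch (AxisVector and element::f32 assumed):
    auto t  = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto tr = std::make_shared<op::Reshape>(t, AxisVector{1, 0}, Shape{3, 2});  // transpose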
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. |
class Reverse : public RequiresTensorViewArgs class Reverse : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reverse operation. /// \brief Constructs a reverse operation.
...@@ -56,7 +56,9 @@ namespace ngraph ...@@ -56,7 +56,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Reverse>(new_args.at(0), m_reversed_axes); return std::make_shared<Reverse>(new_args.at(0), m_reversed_axes);
} }
......
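Sketch of Reverse flipping along axis 1 only (AxisSet and element::f32 assumed):
    auto m  = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto rv = std::make_shared<op::Reverse>(m, AxisSet{1});  // reverses each row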
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq 0\text{, else }\texttt{arg2}[i_1,\dots,i_n]\f$ | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq 0\text{, else }\texttt{arg2}[i_1,\dots,i_n]\f$ |
class Select : public RequiresTensorViewArgs class Select : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a selection operation. /// \brief Constructs a selection operation.
...@@ -53,7 +53,9 @@ namespace ngraph ...@@ -53,7 +53,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 3) if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2)); return std::make_shared<Select>(new_args.at(0), new_args.at(1), new_args.at(2));
} }
......
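Select's predicate input is typically boolean (the formula above treats nonzero as true); element::boolean appears elsewhere in this diff, so a minimal sketch:
    auto cond   = std::make_shared<op::Parameter>(element::boolean, Shape{4});
    auto then_v = std::make_shared<op::Parameter>(element::f32, Shape{4});
    auto else_v = std::make_shared<op::Parameter>(element::f32, Shape{4});
    auto sel    = std::make_shared<op::Select>(cond, then_v, else_v);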
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -68,7 +68,7 @@ namespace ngraph ...@@ -68,7 +68,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------- | /// | ---------------------- | -------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | See above algorithm. | /// | \f$E[d_1,\dots,d_n]\f$ | See above algorithm. |
class SelectAndScatter : public RequiresTensorViewArgs class SelectAndScatter : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a select-and-scatter operation. /// \brief Constructs a select-and-scatter operation.
...@@ -91,7 +91,9 @@ namespace ngraph ...@@ -91,7 +91,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 3) if (new_args.size() != 3)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<SelectAndScatter>(new_args.at(0), return std::make_shared<SelectAndScatter>(new_args.at(0),
new_args.at(1), new_args.at(1),
new_args.at(2), new_args.at(2),
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \text{sgn}(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \text{sgn}(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sign : public UnaryElementwiseArithmetic class Sign : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an elementwise sign operation. /// \brief Constructs an elementwise sign operation.
...@@ -52,7 +52,9 @@ namespace ngraph ...@@ -52,7 +52,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sign>(new_args.at(0)); return std::make_shared<Sign>(new_args.at(0));
} }
}; };
......
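The remaining unary elementwise ops in this diff (Sin, Sinh, Sqrt, Tan, Tanh) are rebased onto util::UnaryElementwiseArithmetic in exactly the same way, so one sketch stands for all of them (element::f32 assumed):
    auto v = std::make_shared<op::Parameter>(element::f32, Shape{16});
    auto s = std::make_shared<op::Sign>(v);   // swap in Sin, Sinh, Sqrt, Tan, or Tanh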
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sin(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sin(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sin : public UnaryElementwiseArithmetic class Sin : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a sine operation. /// \brief Constructs a sine operation.
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sin>(new_args.at(0)); return std::make_shared<Sin>(new_args.at(0));
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sinh(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sinh(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sinh : public UnaryElementwiseArithmetic class Sinh : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a hyperbolic sine operation. /// \brief Constructs a hyperbolic sine operation.
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sinh>(new_args.at(0)); return std::make_shared<Sinh>(new_args.at(0));
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -48,7 +48,7 @@ namespace ngraph ...@@ -48,7 +48,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------------------------------------------------------------ | --------------------------------- | /// | ------------------------------------------------------------------------------ | --------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ where \f$d'_i = \lceil(u_i - l_i)\, /\, s_i\rceil\f$. | The tensor sliced from the input. | /// | \f$E[d'_1,\dots,d'_n]\f$ where \f$d'_i = \lceil(u_i - l_i)\, /\, s_i\rceil\f$. | The tensor sliced from the input. |
class Slice : public RequiresTensorViewArgs class Slice : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a tensor slice operation. /// \brief Constructs a tensor slice operation.
...@@ -76,7 +76,9 @@ namespace ngraph ...@@ -76,7 +76,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Slice>( return std::make_shared<Slice>(
new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides); new_args.at(0), m_lower_bounds, m_upper_bounds, m_strides);
} }
......
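Sketch of a strided Slice taking every other column of the first two rows; the output shape follows the ceiling formula above (Coordinate/Strides and element::f32 assumed):
    auto m  = std::make_shared<op::Parameter>(element::f32, Shape{4, 4});
    auto sl = std::make_shared<op::Slice>(
        m, Coordinate{0, 0}, Coordinate{2, 4}, Strides{1, 2});  // result Shape{2, 2}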
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sqrt{\texttt{arg}[i_1,\dots,i_n]}\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sqrt{\texttt{arg}[i_1,\dots,i_n]}\f$ |
class Sqrt : public UnaryElementwiseArithmetic class Sqrt : public util::UnaryElementwiseArithmetic
{ {
public: public:
            /// \brief Constructs a square root operation.             /// \brief Constructs a square root operation.
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sqrt>(new_args.at(0)); return std::make_shared<Sqrt>(new_args.at(0));
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] - \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] - \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Subtract : public BinaryElementwiseArithmetic class Subtract : public util::BinaryElementwiseArithmetic
{ {
public: public:
            /// \brief Constructs a subtraction operation.             /// \brief Constructs a subtraction operation.
...@@ -52,7 +52,9 @@ namespace ngraph ...@@ -52,7 +52,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Subtract>(new_args.at(0), new_args.at(1)); return std::make_shared<Subtract>(new_args.at(0), new_args.at(1));
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -76,7 +76,7 @@ namespace ngraph ...@@ -76,7 +76,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. | /// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. |
class Sum : public RequiresTensorViewArgs class Sum : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a summation operation. /// \brief Constructs a summation operation.
...@@ -89,7 +89,9 @@ namespace ngraph ...@@ -89,7 +89,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Sum>(new_args.at(0), m_reduction_axes); return std::make_shared<Sum>(new_args.at(0), m_reduction_axes);
} }
......
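Sum needs only the reduction axes; a sketch collapsing axis 0 of a 2x3 tensor (element::f32 assumed, constructor arity from copy_with_new_args above):
    auto m  = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto s0 = std::make_shared<op::Sum>(m, AxisSet{0});  // result Shape{3}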
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tan(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tan(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Tan : public UnaryElementwiseArithmetic class Tan : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a tangent operation. /// \brief Constructs a tangent operation.
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Tan>(new_args.at(0)); return std::make_shared<Tan>(new_args.at(0));
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tanh(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tanh(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Tanh : public UnaryElementwiseArithmetic class Tanh : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a hyperbolic tangent operation. /// \brief Constructs a hyperbolic tangent operation.
...@@ -50,7 +50,9 @@ namespace ngraph ...@@ -50,7 +50,9 @@ namespace ngraph
const std::vector<std::shared_ptr<Node>>& new_args) const override const std::vector<std::shared_ptr<Node>>& new_args) const override
{ {
if (new_args.size() != 1) if (new_args.size() != 1)
{
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
}
return std::make_shared<Tanh>(new_args.at(0)); return std::make_shared<Tanh>(new_args.at(0));
} }
......
...@@ -17,15 +17,15 @@ ...@@ -17,15 +17,15 @@
#include <memory> #include <memory>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwise::BinaryElementwise(const std::string& node_type, op::util::BinaryElementwise::BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type, const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1}) : RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{ {
auto& input_0 = get_inputs().at(0); auto& input_0 = get_inputs().at(0);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary operations, i.e., operations where the same
/// scalar binary operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | ----------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg0` | \f$E_0[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E_0\f$. |
/// | `arg1` | \f$E_1[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape as `arg0`. Subclasses may impose restrictions on the element type \f$E_1\f$. |
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E_2[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, but subclasses must determine the element type \f$E_2\f$. |
class BinaryElementwise : public RequiresTensorViewArgs
{
protected:
                /// \brief Constructs a binary elementwise operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
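The new util::BinaryElementwise base never fixes the result element type itself; each subclass layer supplies it (arithmetic ops pass the input type through, comparisons pass element::boolean, as the .cpp diffs below show). A hypothetical subclass illustrating the constructor contract — MyBinaryOp is not part of this diff:
    class MyBinaryOp : public op::util::BinaryElementwise
    {
    protected:
        MyBinaryOp(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
            : BinaryElementwise("MyBinaryOp", arg0->get_element_type(), arg0, arg1)
        {
        }
    };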
...@@ -14,14 +14,15 @@ ...@@ -14,14 +14,15 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type, op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::shared_ptr<Node>& arg0, const std::string& node_type,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1) : BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type()) if (arg0->get_element_type() != arg1->get_element_type())
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/binary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary arithmetic operations, i.e., operations where the same
/// scalar binary arithmetic operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying arithmetic operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. The element type \f$N\f$ may be any numeric type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors. |
class BinaryElementwiseArithmetic : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
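A hypothetical arithmetic op following the same pattern Power, Remainder, and Subtract use after this refactor (the name FloorMod is illustrative only):
    class FloorMod : public op::util::BinaryElementwiseArithmetic
    {
    public:
        FloorMod(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1)
            : BinaryElementwiseArithmetic("FloorMod", arg0, arg1)
        {
        }
    };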
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type, op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type,
const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1) const shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, element::boolean, arg0, arg1) : BinaryElementwise(node_type, element::boolean, arg0, arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type()) if (arg0->get_element_type() != arg1->get_element_type())
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/binary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same
/// scalar binary comparison operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying comparison operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseComparison : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseComparison(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
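Comparison subclasses inherit the element::boolean result type set in the constructor above; a sketch using the existing op::Less (element::f32 assumed):
    auto x  = std::make_shared<op::Parameter>(element::f32, Shape{8});
    auto y  = std::make_shared<op::Parameter>(element::f32, Shape{8});
    auto lt = std::make_shared<op::Less>(x, y);  // output element type: element::boolean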
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <memory>
#include <sstream>
#include "ngraph/except.hpp"
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
using namespace std;
op::util::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::string& node_type,
const Nodes& args)
: Op(node_type, args)
{
for (auto arg : args)
{
if (arg->get_output_size() != 1)
{
throw ngraph_error("Arguments for node type \"" + node_type +
"\" must be tensor views");
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
            /// \brief Abstract base class for ops on tensor views.
class RequiresTensorViewArgs : public ngraph::op::Op
{
protected:
/// \brief Constructs an operation on tensor view arguments.
///
/// \param args The nodes producing this node's input tensors.
RequiresTensorViewArgs(const std::string& node_type, const Nodes& args);
};
}
}
}
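Sketch of how an op opts into the tensor-view guard: deriving from util::RequiresTensorViewArgs makes the constructor above reject any argument whose output count is not exactly one (MyTensorOp is hypothetical):
    class MyTensorOp : public op::util::RequiresTensorViewArgs
    {
    public:
        MyTensorOp(const std::shared_ptr<Node>& arg)
            : RequiresTensorViewArgs("MyTensorOp", Nodes{arg})
        {
        }
    };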
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::UnaryElementwise::UnaryElementwise(const std::string& node_type, op::util::UnaryElementwise::UnaryElementwise(const std::string& node_type,
const element::Type& result_element_type, const element::Type& result_element_type,
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg}) : RequiresTensorViewArgs(node_type, Nodes{arg})
{ {
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
......
File mode changed from 100644 to 100755