Commit 5ce09de0 authored by pthoreho's avatar pthoreho

Merge remote-tracking branch 'origin/master' into pruthvi/max_pooling

parents c6672b3d fbf5a0cf
# API Changes # API Changes
## Changes to ops
* The namespace `ngraph::op` is only for actual ops. Helpers have been moved into
`ngraph::op::util`:
+ `BinaryElementwiseArithmetic`
+ `BinaryElementwiseComparison`
+ `BinaryElementwise`
+ `RequiresTensorViewArgs`
+ `UnaryElementwiseArithmetic`
+ `UnaryElementwise`
Ops defined outside of nGraph core will need to get the base class from `ngraph::op::util` and
change the include file to `#include "ngraph/ops/util/requires_tensor_view_args.hpp"`, etc.
See any of the core ops for an example.
## Changes to convolution and pooling ops ## Changes to convolution and pooling ops
* Backprop ops have been added for convolution ops. * Backprop ops have been added for convolution ops.
......
.. constant.rst:
########
Constant
########
Description
===========
Literal constant tensor.
The output is a tensor initialized from the ``values`` attribute.
Attributes
----------
+-----------------+------------------------------+---------------------------------------+
| Name | Type | Notes |
+=================+==============================+=======================================+
| ``type`` | ``ngraph::element::type`` | The element type of the value |
| | | in the computation. |
+-----------------+------------------------------+---------------------------------------+
| ``shape`` | ``ngraph::Shape`` | The shape of the constant. |
+-----------------+------------------------------+---------------------------------------+
| ``values`` | ``const std::vector<T>&`` | Constant elements in row-major order. |
| | | T must be compatible with the element |
| | | type. |
+-----------------+------------------------------+---------------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | ``type`` | ``shape`` |
+-----------------+-------------------------+--------------------------------+
C++ Interface
=============
.. doxygenclass:: ngraph::op::Constant
:members:
.. convert.rst:
#######
Convert
#######
Description
===========
Convert a tensor from one element type to another.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Attributes
----------
+------------------+---------------------------+---------------------------------+
| Name | Type | Notes |
+==================+===========================+=================================+
| ``element_type`` | ``ngraph::element::type`` | The element type of the result. |
+------------------+---------------------------+---------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | ``element_type`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow \texttt{Convert}(\Delta,\texttt{arg->get_element_type()})
C++ Interface
=============
.. doxygenclass:: ngraph::op::Convert
:members:
.. cos.rst:
###
Cos
###
Description
===========
Elementwise cosine operation.
Produces a tensor of the same element type and shape as ``arg``,
where the value at each coordinate of ``output`` is the cosine of the
value at the corresponding coordinate of ``arg``.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | Same as ``arg`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
\texttt{output}_{i_0, \ldots, i_{n-1}} = \cos(\texttt{arg}_{i_0, \ldots, i_{n-1}})
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow -\Delta\ \sin(\texttt{arg})
C++ Interface
=============
.. doxygenclass:: ngraph::op::Cos
:members:
.. cosh.rst:
####
Cosh
####
Description
===========
Elementwise hyperbolic cosine operation.
Produces a tensor of the same element type and shape as ``arg``, where
the value at each coordinate of ``output`` is the hyperbolic cosine of
the value at the corresponding coordinate of ``arg``.
Inputs
------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``arg`` | Any | Any |
+-----------------+-------------------------+--------------------------------+
Outputs
-------
+-----------------+-------------------------+--------------------------------+
| Name | Element Type | Shape |
+=================+=========================+================================+
| ``output`` | Same as ``arg`` | Same as ``arg``. |
+-----------------+-------------------------+--------------------------------+
Mathematical Definition
=======================
.. math::
\texttt{output}_{i_0, \ldots, i_{n-1}} = \cosh(\texttt{arg}_{i_0, \ldots, i_{n-1}})
Backprop
========
.. math::
\overline{\texttt{arg}} \leftarrow \Delta\ \sinh(\texttt{arg})
C++ Interface
=============
.. doxygenclass:: ngraph::op::Cosh
:members:
...@@ -58,4 +58,9 @@ Not currently a comprehensive list. ...@@ -58,4 +58,9 @@ Not currently a comprehensive list.
broadcast.rst broadcast.rst
ceiling.rst ceiling.rst
concatenate.rst concatenate.rst
constant.rst
convert.rst
convolution.rst convolution.rst
cos.rst
cosh.rst
...@@ -37,9 +37,6 @@ set (SRC ...@@ -37,9 +37,6 @@ set (SRC
ops/add.cpp ops/add.cpp
ops/avg_pool.cpp ops/avg_pool.cpp
ops/batch_norm.cpp ops/batch_norm.cpp
ops/binary_elementwise_arithmetic.cpp
ops/binary_elementwise_comparison.cpp
ops/binary_elementwise.cpp
ops/broadcast.cpp ops/broadcast.cpp
ops/concatenate.cpp ops/concatenate.cpp
ops/constant.cpp ops/constant.cpp
...@@ -79,8 +76,12 @@ set (SRC ...@@ -79,8 +76,12 @@ set (SRC
ops/sum.cpp ops/sum.cpp
ops/tan.cpp ops/tan.cpp
ops/tanh.cpp ops/tanh.cpp
ops/unary_elementwise_arithmetic.cpp ops/util/binary_elementwise_arithmetic.cpp
ops/unary_elementwise.cpp ops/util/binary_elementwise_comparison.cpp
ops/util/binary_elementwise.cpp
ops/util/requires_tensor_view_args.cpp
ops/util/unary_elementwise_arithmetic.cpp
ops/util/unary_elementwise.cpp
pass/dump_sorted.cpp pass/dump_sorted.cpp
pass/graph_rewrite.cpp pass/graph_rewrite.cpp
pass/inliner.cpp pass/inliner.cpp
......
...@@ -208,19 +208,8 @@ void codegen::StaticCompiler::initialize() ...@@ -208,19 +208,8 @@ void codegen::StaticCompiler::initialize()
} }
// Enable various target features // Enable various target features
// Most of these are for Eigen
auto& TO = m_compiler->getInvocation().getTargetOpts(); auto& TO = m_compiler->getInvocation().getTargetOpts();
TO.CPU = sys::getHostCPUName(); TO.CPU = sys::getHostCPUName();
TO.FeaturesAsWritten.emplace_back("+sse");
TO.FeaturesAsWritten.emplace_back("+sse2");
TO.FeaturesAsWritten.emplace_back("+sse3");
TO.FeaturesAsWritten.emplace_back("+ssse3");
TO.FeaturesAsWritten.emplace_back("+sse4.1");
TO.FeaturesAsWritten.emplace_back("+sse4.2");
TO.FeaturesAsWritten.emplace_back("+avx");
TO.FeaturesAsWritten.emplace_back("+avx2");
TO.FeaturesAsWritten.emplace_back("+fma");
} }
codegen::StaticCompiler::~StaticCompiler() codegen::StaticCompiler::~StaticCompiler()
......
...@@ -26,45 +26,45 @@ ...@@ -26,45 +26,45 @@
using namespace std; using namespace std;
namespace nervana namespace ngraph
{ {
class thread_starter; class thread_starter;
} }
string nervana::logger::log_path; string ngraph::logger::log_path;
deque<string> nervana::logger::queue; deque<string> ngraph::logger::queue;
static mutex queue_mutex; static mutex queue_mutex;
static condition_variable queue_condition; static condition_variable queue_condition;
static unique_ptr<thread> queue_thread; static unique_ptr<thread> queue_thread;
static bool active = false; static bool active = false;
std::ostream& nervana::get_nil_stream() std::ostream& ngraph::get_nil_stream()
{ {
static std::stringstream nil; static std::stringstream nil;
return nil; return nil;
} }
class nervana::thread_starter class ngraph::thread_starter
{ {
public: public:
thread_starter() { nervana::logger::start(); } thread_starter() { ngraph::logger::start(); }
virtual ~thread_starter() { nervana::logger::stop(); } virtual ~thread_starter() { ngraph::logger::stop(); }
}; };
static nervana::thread_starter _starter; static ngraph::thread_starter _starter;
void nervana::logger::set_log_path(const string& path) void ngraph::logger::set_log_path(const string& path)
{ {
log_path = path; log_path = path;
} }
void nervana::logger::start() void ngraph::logger::start()
{ {
active = true; active = true;
queue_thread = unique_ptr<thread>(new thread(&thread_entry, nullptr)); queue_thread = unique_ptr<thread>(new thread(&thread_entry, nullptr));
} }
void nervana::logger::stop() void ngraph::logger::stop()
{ {
{ {
unique_lock<std::mutex> lk(queue_mutex); unique_lock<std::mutex> lk(queue_mutex);
...@@ -74,12 +74,12 @@ void nervana::logger::stop() ...@@ -74,12 +74,12 @@ void nervana::logger::stop()
queue_thread->join(); queue_thread->join();
} }
void nervana::logger::process_event(const string& s) void ngraph::logger::process_event(const string& s)
{ {
cout << s << "\n"; cout << s << "\n";
} }
void nervana::logger::thread_entry(void* param) void ngraph::logger::thread_entry(void* param)
{ {
unique_lock<std::mutex> lk(queue_mutex); unique_lock<std::mutex> lk(queue_mutex);
while (active) while (active)
...@@ -93,14 +93,14 @@ void nervana::logger::thread_entry(void* param) ...@@ -93,14 +93,14 @@ void nervana::logger::thread_entry(void* param)
} }
} }
void nervana::logger::log_item(const string& s) void ngraph::logger::log_item(const string& s)
{ {
unique_lock<std::mutex> lk(queue_mutex); unique_lock<std::mutex> lk(queue_mutex);
queue.push_back(s); queue.push_back(s);
queue_condition.notify_one(); queue_condition.notify_one();
} }
nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const char* func) ngraph::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const char* func)
{ {
switch (type) switch (type)
{ {
...@@ -124,7 +124,7 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const ...@@ -124,7 +124,7 @@ nervana::log_helper::log_helper(LOG_TYPE type, const char* file, int line, const
_stream << "\t"; _stream << "\t";
} }
nervana::log_helper::~log_helper() ngraph::log_helper::~log_helper()
{ {
cout << _stream.str() << endl; cout << _stream.str() << endl;
// logger::log_item(_stream.str()); // logger::log_item(_stream.str());
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
namespace nervana namespace ngraph
{ {
class conststring class conststring
{ {
...@@ -93,30 +93,30 @@ namespace nervana ...@@ -93,30 +93,30 @@ namespace nervana
extern std::ostream& get_nil_stream(); extern std::ostream& get_nil_stream();
#define NGRAPH_ERR \ #define NGRAPH_ERR \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_ERROR, \ ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_ERROR, \
nervana::get_file_name(__FILE__), \ ngraph::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
#define NGRAPH_WARN \ #define NGRAPH_WARN \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_WARNING, \ ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_WARNING, \
nervana::get_file_name(__FILE__), \ ngraph::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
#define NGRAPH_INFO \ #define NGRAPH_INFO \
nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_INFO, \ ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_INFO, \
nervana::get_file_name(__FILE__), \ ngraph::get_file_name(__FILE__), \
__LINE__, \ __LINE__, \
__PRETTY_FUNCTION__) \ __PRETTY_FUNCTION__) \
.stream() .stream()
// #define NGRAPH_DEBUG \ // #define NGRAPH_DEBUG \
// nervana::log_helper(nervana::LOG_TYPE::_LOG_TYPE_DEBUG, \ // ngraph::log_helper(ngraph::LOG_TYPE::_LOG_TYPE_DEBUG, \
// nervana::get_file_name(__FILE__), \ // ngraph::get_file_name(__FILE__), \
// __LINE__, \ // __LINE__, \
// __PRETTY_FUNCTION__) \ // __PRETTY_FUNCTION__) \
// .stream() // .stream()
#define NGRAPH_DEBUG nervana::get_nil_stream() #define NGRAPH_DEBUG ngraph::get_nil_stream()
} }
...@@ -144,42 +144,6 @@ void Node::set_name(const string& name) ...@@ -144,42 +144,6 @@ void Node::set_name(const string& name)
} }
} }
void Node::assert_argument_list_equivalency(const Nodes& b)
{
bool arguments_equal = true;
if (this->m_arguments.size() == b.size())
{
for (size_t i = 0; i < this->m_arguments.size(); i++)
{
arguments_equal = arguments_equal && this->m_arguments.at(i) == b.at(i);
}
}
else
{
arguments_equal = false;
}
if (!arguments_equal)
{
std::cout << "node = " << this->get_name() << std::endl;
std::cout << "m_arguments" << std::endl;
for (auto arg : this->m_arguments)
{
std::cout << "arg = " << arg->get_name() << std::endl;
}
std::cout << "results" << std::endl;
for (auto arg : b)
{
std::cout << "arg = " << arg->get_name() << std::endl;
}
}
if (!arguments_equal)
{
throw "Arguments aren't equal";
}
}
std::shared_ptr<Node> Node::get_input_op(size_t index) std::shared_ptr<Node> Node::get_input_op(size_t index)
{ {
for (auto arg : m_arguments) for (auto arg : m_arguments)
...@@ -201,7 +165,10 @@ Nodes Node::get_input_ops() //const ...@@ -201,7 +165,10 @@ Nodes Node::get_input_ops() //const
result.push_back(i.get_output().get_node()); result.push_back(i.get_output().get_node());
} }
} }
assert_argument_list_equivalency(result); if (m_arguments != result)
{
throw ngraph_error("Arguments aren't equal: different values");
}
return result; return result;
} }
......
...@@ -170,8 +170,6 @@ namespace ngraph ...@@ -170,8 +170,6 @@ namespace ngraph
protected: protected:
void add_output(const element::Type& element_type, const Shape& shape); void add_output(const element::Type& element_type, const Shape& shape);
void assert_argument_list_equivalency(const Nodes& b);
bool test_identical(const Node&) const;
std::string m_node_type; std::string m_node_type;
std::multiset<Node*> m_users; std::multiset<Node*> m_users;
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
/// \brief Elementwise absolute value operation. /// \brief Elementwise absolute value operation.
/// ///
class Abs : public UnaryElementwiseArithmetic class Abs : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an absolute value operation. /// \brief Constructs an absolute value operation.
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
/// \brief Elementwise inverse cosine (arccos) operation. /// \brief Elementwise inverse cosine (arccos) operation.
/// ///
class Acos : public UnaryElementwiseArithmetic class Acos : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an arccos operation. /// \brief Constructs an arccos operation.
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
/// \brief Elementwise addition operation. /// \brief Elementwise addition operation.
/// ///
class Add : public BinaryElementwiseArithmetic class Add : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an addition operation. /// \brief Constructs an addition operation.
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
#ifdef NGRAPH_DISTRIBUTED #ifdef NGRAPH_DISTRIBUTED
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
class AllReduce : public RequiresTensorViewArgs class AllReduce : public util::RequiresTensorViewArgs
{ {
public: public:
AllReduce(const std::shared_ptr<Node>& arg); AllReduce(const std::shared_ptr<Node>& arg);
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
/// \brief Elementwise inverse sine (arcsin) operation. /// \brief Elementwise inverse sine (arcsin) operation.
/// ///
class Asin : public UnaryElementwiseArithmetic class Asin : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an arcsin operation. /// \brief Constructs an arcsin operation.
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
/// \brief Elementwise inverse tangent (arctan) operation. /// \brief Elementwise inverse tangent (arctan) operation.
/// ///
class Atan : public UnaryElementwiseArithmetic class Atan : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an arctan operation. /// \brief Constructs an arctan operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -24,7 +24,7 @@ namespace ngraph ...@@ -24,7 +24,7 @@ namespace ngraph
{ {
/// \brief Batched average pooling operation, with optional padding and window stride. /// \brief Batched average pooling operation, with optional padding and window stride.
/// ///
class AvgPool : public RequiresTensorViewArgs class AvgPool : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a batched average pooling operation. /// \brief Constructs a batched average pooling operation.
...@@ -98,7 +98,7 @@ namespace ngraph ...@@ -98,7 +98,7 @@ namespace ngraph
Shape m_padding_above; Shape m_padding_above;
}; };
class AvgPoolBackprop : public RequiresTensorViewArgs class AvgPoolBackprop : public util::RequiresTensorViewArgs
{ {
public: public:
AvgPoolBackprop(const Shape& forward_arg_shape, AvgPoolBackprop(const Shape& forward_arg_shape,
......
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
#include <memory> #include <memory>
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
class BatchNorm : public RequiresTensorViewArgs class BatchNorm : public util::RequiresTensorViewArgs
{ {
public: public:
BatchNorm(double eps, BatchNorm(double eps,
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes. /// \brief Operation which "adds" axes to an input tensor, replicating elements from the input as needed along the new axes.
class Broadcast : public RequiresTensorViewArgs class Broadcast : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a conversion operation. /// \brief Constructs a conversion operation.
......
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise ceiling operation. /// \brief Elementwise ceiling operation.
class Ceiling : public UnaryElementwiseArithmetic class Ceiling : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a ceiling operation. /// \brief Constructs a ceiling operation.
......
...@@ -18,14 +18,14 @@ ...@@ -18,14 +18,14 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Concatenation operation. /// \brief Concatenation operation.
class Concat : public RequiresTensorViewArgs class Concat : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a concatenation operation. /// \brief Constructs a concatenation operation.
......
...@@ -29,20 +29,6 @@ namespace ngraph ...@@ -29,20 +29,6 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Class for constants. /// \brief Class for constants.
///
/// ## Parameters
///
/// | | Description |
/// | --------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | `type` | The ngraph::element::Type of the tensor constant. |
/// | `shape` | The ngraph::Shape of the tensor constant. |
/// | `values` | A list of values to initialize the underlying tensor constant. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | --------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | A constant tensor with the specified element type, shape, and values. |
class Constant : public Node class Constant : public Node
{ {
public: public:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
...@@ -24,28 +24,7 @@ namespace ngraph ...@@ -24,28 +24,7 @@ namespace ngraph
namespace op namespace op
{ {
/// \brief Elementwise type conversion operation. /// \brief Elementwise type conversion operation.
/// class Convert : public util::UnaryElementwise
/// Each scalar in the input tensor is converted to the specified output element type. Note that the conversion may
/// result in loss of precision. For example, conversion from `float32` to `int32` is allowed.
///
/// ## Parameters
///
/// | | Description |
/// | -------------- | ---------------------------------------- |
/// | `element_type` | The element type \f$E'\f$ to convert to. |
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | ----------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ----------------------- | --------------------------------------------------------------------------------------------------------- |
/// | \f$E'[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{convert}_{(E,E')}(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Convert : public UnaryElementwise
{ {
public: public:
/// \brief Constructs a conversion operation. /// \brief Constructs a conversion operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -24,7 +24,7 @@ namespace ngraph ...@@ -24,7 +24,7 @@ namespace ngraph
{ {
/// \brief Batched convolution operation, with optional window dilation and stride. /// \brief Batched convolution operation, with optional window dilation and stride.
/// ///
class Convolution : public RequiresTensorViewArgs class Convolution : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a batched convolution operation. /// \brief Constructs a batched convolution operation.
...@@ -151,7 +151,7 @@ namespace ngraph ...@@ -151,7 +151,7 @@ namespace ngraph
}; };
/// \brief Data batch backprop for batched convolution operation. /// \brief Data batch backprop for batched convolution operation.
class ConvolutionBackpropData : public RequiresTensorViewArgs class ConvolutionBackpropData : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a batched-convolution data batch-backprop operation. /// \brief Constructs a batched-convolution data batch-backprop operation.
...@@ -246,7 +246,7 @@ namespace ngraph ...@@ -246,7 +246,7 @@ namespace ngraph
}; };
/// \brief Filters backprop for batched convolution operation. /// \brief Filters backprop for batched convolution operation.
class ConvolutionBackpropFilters : public RequiresTensorViewArgs class ConvolutionBackpropFilters : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a batched-convolution filter-backprop operation. /// \brief Constructs a batched-convolution filter-backprop operation.
......
...@@ -16,26 +16,14 @@ ...@@ -16,26 +16,14 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise cosine operation. /// \brief Elementwise cosine operation.
/// class Cos : public util::UnaryElementwiseArithmetic
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | ----------------------------------------------- |
/// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \cos(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Cos : public UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a cosine operation. /// \brief Constructs a cosine operation.
......
...@@ -16,26 +16,14 @@ ...@@ -16,26 +16,14 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise hyperbolic cosine (cosh) operation. /// \brief Elementwise hyperbolic cosine (cosh) operation.
/// class Cosh : public util::UnaryElementwiseArithmetic
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | ----------------------------------------------- |
/// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \cosh(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Cosh : public UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a hyperbolic cosine operation. /// \brief Constructs a hyperbolic cosine operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mathbin{/} \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mathbin{/} \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Divide : public BinaryElementwiseArithmetic class Divide : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include <utility> #include <utility>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -54,7 +54,7 @@ namespace ngraph ...@@ -54,7 +54,7 @@ namespace ngraph
/// | ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n,d''_1,\dots,d''_p]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \Sigma_{0 \le j_1 < d'_1, \dots, 0 \le j_m < d'_m}(\mathtt{arg0}[i_1,\dots,i_n,j_1,\dots,j_m] \cdot \mathtt{arg1}[j_1,\dots,j_m,k_1,\dots,k_p])\f$ or, if \f$m = 0\f$, \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \mathtt{arg0}[i_1,\dots,i_n] \cdot \mathtt{arg1}[k_1,\dots,k_p]\f$. | /// | \f$E[d_1,\dots,d_n,d''_1,\dots,d''_p]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \Sigma_{0 \le j_1 < d'_1, \dots, 0 \le j_m < d'_m}(\mathtt{arg0}[i_1,\dots,i_n,j_1,\dots,j_m] \cdot \mathtt{arg1}[j_1,\dots,j_m,k_1,\dots,k_p])\f$ or, if \f$m = 0\f$, \f$T[i_1,\dots,i_n,k_1,\dots,k_p] = \mathtt{arg0}[i_1,\dots,i_n] \cdot \mathtt{arg1}[k_1,\dots,k_p]\f$. |
/// ///
class Dot : public RequiresTensorViewArgs class Dot : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a dot product operation. /// \brief Constructs a dot product operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | /// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class Equal : public BinaryElementwiseComparison class Equal : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs an is-equal operation. /// \brief Constructs an is-equal operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \exp(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \exp(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Exp : public UnaryElementwiseArithmetic class Exp : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an exponential operation. /// \brief Constructs an exponential operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ---------------------------------------------------------------------------------------------- | /// | ---------------------- | ---------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \lfloor \texttt{arg}[i_1,\dots,i_n] \rfloor\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \lfloor \texttt{arg}[i_1,\dots,i_n] \rfloor\f$ |
class Floor : public UnaryElementwiseArithmetic class Floor : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a floor operation. /// \brief Constructs a floor operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \gt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \gt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class Greater : public BinaryElementwiseComparison class Greater : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs a greater-than operation. /// \brief Constructs a greater-than operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \geq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \geq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class GreaterEq : public BinaryElementwiseComparison class GreaterEq : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs a greater-than-or-equal operation. /// \brief Constructs a greater-than-or-equal operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \lt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \lt \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class Less : public BinaryElementwiseComparison class Less : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs a less-than operation. /// \brief Constructs a less-than operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \leq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \leq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class LessEq : public BinaryElementwiseComparison class LessEq : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------- | /// | ---------------------- | ----------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \ln(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \ln(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Log : public UnaryElementwiseArithmetic class Log : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a natural log operation. /// \brief Constructs a natural log operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
/// T_\textit{out}[a,c,i_1,\dots,i_n] = \max_{j_1 = s_1 i_1, \dots, j_n = s_n i_n}^{j_1 = s_1 i_1 + w_1 - 1, \dots, j_n = s_n i_n + w_n - 1} (T_\textit{in}[a,c,j_1,\dots,j_n]) /// T_\textit{out}[a,c,i_1,\dots,i_n] = \max_{j_1 = s_1 i_1, \dots, j_n = s_n i_n}^{j_1 = s_1 i_1 + w_1 - 1, \dots, j_n = s_n i_n + w_n - 1} (T_\textit{in}[a,c,j_1,\dots,j_n])
/// \f] /// \f]
/// ///
class MaxPool : public RequiresTensorViewArgs class MaxPool : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
...@@ -101,7 +101,7 @@ namespace ngraph ...@@ -101,7 +101,7 @@ namespace ngraph
Shape m_padding_above; Shape m_padding_above;
}; };
class MaxPoolBackprop : public RequiresTensorViewArgs class MaxPoolBackprop : public util::RequiresTensorViewArgs
{ {
public: public:
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \max(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \max(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ |
class Maximum : public BinaryElementwiseArithmetic class Maximum : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \min(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \min(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$ |
class Minimum : public BinaryElementwiseArithmetic class Minimum : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \cdot \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \cdot \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Multiply : public BinaryElementwiseArithmetic class Multiply : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | --------------------------------------------------------------------------------- | /// | ---------------------- | --------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = -(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = -(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Negative : public UnaryElementwiseArithmetic class Negative : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a negation operation. /// \brief Constructs a negation operation.
......
...@@ -21,6 +21,6 @@ using namespace ngraph; ...@@ -21,6 +21,6 @@ using namespace ngraph;
using namespace std; using namespace std;
op::Not::Not(const shared_ptr<Node>& arg) op::Not::Not(const shared_ptr<Node>& arg)
: op::UnaryElementwise("Not", arg->get_element_type(), arg) : UnaryElementwise("Not", arg->get_element_type(), arg)
{ {
} }
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg}[i_1,\dots,i_n] = 0\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg}[i_1,\dots,i_n] = 0\text{, else } 0\f$ |
class Not : public UnaryElementwise class Not : public util::UnaryElementwise
{ {
public: public:
/// \brief Constructs a logical negation operation. /// \brief Constructs a logical negation operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ | /// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
class NotEqual : public BinaryElementwiseComparison class NotEqual : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs a not-equal operation. /// \brief Constructs a not-equal operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -42,7 +42,7 @@ namespace ngraph ...@@ -42,7 +42,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T'\f$, where \f$T'[i_1,\dots,i_{m-1},i_m,i_{m+1},\dots,i_n] = 1\f$ if \f$T[i_1,\dots,i_{m-1},i_{m+1},\dots,i_n] = i_m\f$, else \f$0\f$. However, \f$T'\f$ is undefined if any non-integral value or any out-of-bounds value is detected in the input tensor. |
class OneHot : public RequiresTensorViewArgs class OneHot : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a one-hot operation. /// \brief Constructs a one-hot operation.
......
...@@ -18,23 +18,14 @@ ...@@ -18,23 +18,14 @@
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include "ngraph/except.hpp" #include "ngraph/common.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
op::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::string& node_type, op::Op::Op(const std::string& node_type, const Nodes& args)
const std::vector<std::shared_ptr<Node>>& args)
: Node(node_type, args) : Node(node_type, args)
{ {
for (auto arg : args)
{
if (arg->get_output_size() != 1)
{
throw ngraph_error("Arguments for node type \"" + node_type +
"\" must be tensor views");
}
}
} }
This diff is collapsed.
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -56,7 +56,7 @@ namespace ngraph ...@@ -56,7 +56,7 @@ namespace ngraph
/// (Note that `below` and `above` here refer respectively to lower- or higher-numbered coordinate indices, and numbering starts at the upper-left corner; /// (Note that `below` and `above` here refer respectively to lower- or higher-numbered coordinate indices, and numbering starts at the upper-left corner;
/// thus inserting a row "below" actually inserts it at the "top" of the matrix.) /// thus inserting a row "below" actually inserts it at the "top" of the matrix.)
/// ///
class Pad : public RequiresTensorViewArgs class Pad : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a generic padding operation. /// \brief Constructs a generic padding operation.
......
...@@ -22,7 +22,7 @@ using namespace std; ...@@ -22,7 +22,7 @@ using namespace std;
using namespace ngraph; using namespace ngraph;
op::Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape) op::Parameter::Parameter(const ngraph::element::Type& element_type, const Shape& shape)
: Node("Parameter", {}) : Op("Parameter", {})
{ {
add_output(element_type, shape); add_output(element_type, shape);
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/node.hpp" #include "ngraph/ops/op.hpp"
#include "ngraph/types/type.hpp" #include "ngraph/types/type.hpp"
namespace ngraph namespace ngraph
...@@ -41,7 +41,7 @@ namespace ngraph ...@@ -41,7 +41,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------- | --------------------------------------------------------------------------------------------------------------------------- | /// | ------- | --------------------------------------------------------------------------------------------------------------------------- |
/// | \f$T\f$ | The value of the parameter, supplied by the `FunctionCall` to this function or in the initial `ngraph::runtime::CallFrame`. | /// | \f$T\f$ | The value of the parameter, supplied by the `FunctionCall` to this function or in the initial `ngraph::runtime::CallFrame`. |
class Parameter : public Node class Parameter : public op::Op
{ {
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
class Power : public BinaryElementwiseArithmetic class Power : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an exponentiation operation. /// \brief Constructs an exponentiation operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -82,7 +82,7 @@ namespace ngraph ...@@ -82,7 +82,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. | /// | \f$E[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by reduction. |
class Reduce : public RequiresTensorViewArgs class Reduce : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reduction operation. /// \brief Constructs a reduction operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -50,7 +50,7 @@ namespace ngraph ...@@ -50,7 +50,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{reduce}(\mathit{reduction\_function},\mathit{arg\_init},V)\f$ where \f$V\f$ is the set of values in the input tensor within the window defined by the lower bound \f$(s_1i_1,\dots,s_ni_n)\f$ and the noninclusive upper bound \f$(s_1i_1 + w_1,\dots,s_ni_n + w_n)\f$. | /// | \f$E[d'_1,\dots,d'_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{reduce}(\mathit{reduction\_function},\mathit{arg\_init},V)\f$ where \f$V\f$ is the set of values in the input tensor within the window defined by the lower bound \f$(s_1i_1,\dots,s_ni_n)\f$ and the noninclusive upper bound \f$(s_1i_1 + w_1,\dots,s_ni_n + w_n)\f$. |
class ReduceWindow : public RequiresTensorViewArgs class ReduceWindow : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reduce-window operation. /// \brief Constructs a reduce-window operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -38,7 +38,7 @@ namespace ngraph ...@@ -38,7 +38,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ----------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mod \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] \mod \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Remainder : public BinaryElementwiseArithmetic class Remainder : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a remainder operation. /// \brief Constructs a remainder operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -45,7 +45,7 @@ namespace ngraph ...@@ -45,7 +45,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$ where \f$T[i_1,\dots,i_n] = \texttt{arg1}[j_1,\dots,j_n]\f$ if \f$j_1,\dots,j_n\f$ is in bounds for `arg1` and for all \f$m\f$, \f$i_m = l_m + j_m s_m\f$, otherwise \f$\texttt{arg0}[i_1,\dots,i_n]\f$. | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$ where \f$T[i_1,\dots,i_n] = \texttt{arg1}[j_1,\dots,j_n]\f$ if \f$j_1,\dots,j_n\f$ is in bounds for `arg1` and for all \f$m\f$, \f$i_m = l_m + j_m s_m\f$, otherwise \f$\texttt{arg0}[i_1,\dots,i_n]\f$. |
class ReplaceSlice : public RequiresTensorViewArgs class ReplaceSlice : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a tensor slice replacement operation. /// \brief Constructs a tensor slice replacement operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -55,7 +55,7 @@ namespace ngraph ...@@ -55,7 +55,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------ | /// | ------------------------ | ------------------------------------------------------------------------------------------------------ |
/// | \f$E[d'_1,\dots,d'_m]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with its elements rearranged as described above. | /// | \f$E[d'_1,\dots,d'_m]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with its elements rearranged as described above. |
class Reshape : public RequiresTensorViewArgs class Reshape : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reshape operation. /// \brief Constructs a reshape operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -43,7 +43,7 @@ namespace ngraph ...@@ -43,7 +43,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg}[j_1,\dots,j_n]\f$ and \f$j_k = d_k - i_k - 1\f$ if axis \f$k\f$ is in the reverse set; else \f$j_k = i_k\f$. |
class Reverse : public RequiresTensorViewArgs class Reverse : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a reverse operation. /// \brief Constructs a reverse operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq 0\text{, else }\texttt{arg2}[i_1,\dots,i_n]\f$ | /// | \f$E[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{ if }\texttt{arg0}[i_1,\dots,i_n] \neq 0\text{, else }\texttt{arg2}[i_1,\dots,i_n]\f$ |
class Select : public RequiresTensorViewArgs class Select : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a selection operation. /// \brief Constructs a selection operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -68,7 +68,7 @@ namespace ngraph ...@@ -68,7 +68,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------- | /// | ---------------------- | -------------------- |
/// | \f$E[d_1,\dots,d_n]\f$ | See above algorithm. | /// | \f$E[d_1,\dots,d_n]\f$ | See above algorithm. |
class SelectAndScatter : public RequiresTensorViewArgs class SelectAndScatter : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a select-and-scatter operation. /// \brief Constructs a select-and-scatter operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -37,7 +37,7 @@ namespace ngraph ...@@ -37,7 +37,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \text{sgn}(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \text{sgn}(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sign : public UnaryElementwiseArithmetic class Sign : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an elementwise sign operation. /// \brief Constructs an elementwise sign operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sin(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sin(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sin : public UnaryElementwiseArithmetic class Sin : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a sine operation. /// \brief Constructs a sine operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sinh(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sinh(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Sinh : public UnaryElementwiseArithmetic class Sinh : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a hyperbolic sine operation. /// \brief Constructs a hyperbolic sine operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -48,7 +48,7 @@ namespace ngraph ...@@ -48,7 +48,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ------------------------------------------------------------------------------ | --------------------------------- | /// | ------------------------------------------------------------------------------ | --------------------------------- |
/// | \f$E[d'_1,\dots,d'_n]\f$ where \f$d'_i = \lceil(u_i - l_i)\, /\, s_i\rceil\f$. | The tensor sliced from the input. | /// | \f$E[d'_1,\dots,d'_n]\f$ where \f$d'_i = \lceil(u_i - l_i)\, /\, s_i\rceil\f$. | The tensor sliced from the input. |
class Slice : public RequiresTensorViewArgs class Slice : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a tensor slice operation. /// \brief Constructs a tensor slice operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sqrt{\texttt{arg}[i_1,\dots,i_n]}\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \sqrt{\texttt{arg}[i_1,\dots,i_n]}\f$ |
class Sqrt : public UnaryElementwiseArithmetic class Sqrt : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a square operation. /// \brief Constructs a square operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -36,7 +36,7 @@ namespace ngraph ...@@ -36,7 +36,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- | /// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] - \texttt{arg1}[i_1,\dots,i_n]\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n] - \texttt{arg1}[i_1,\dots,i_n]\f$ |
class Subtract : public BinaryElementwiseArithmetic class Subtract : public util::BinaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs an subtraction operation. /// \brief Constructs an subtraction operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -76,7 +76,7 @@ namespace ngraph ...@@ -76,7 +76,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- | /// | ----------------------------------------- | ---------------------------------------------------------------------------------------------------------------- |
/// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. | /// | \f$N[\textit{delete}(A,d_1,\dots,d_n)]\f$ | The tensor \f$T\f$, where \f$T\f$ is the input tensor with the `reduction_axes` \f$A\f$ eliminated by summation. |
class Sum : public RequiresTensorViewArgs class Sum : public util::RequiresTensorViewArgs
{ {
public: public:
/// \brief Constructs a summation operation. /// \brief Constructs a summation operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------ | /// | ---------------------- | ------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tan(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tan(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Tan : public UnaryElementwiseArithmetic class Tan : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a tangent operation. /// \brief Constructs a tangent operation.
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -35,7 +35,7 @@ namespace ngraph ...@@ -35,7 +35,7 @@ namespace ngraph
/// | Type | Description | /// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------- | /// | ---------------------- | ------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tanh(\texttt{arg}[i_1,\dots,i_n])\f$ | /// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \tanh(\texttt{arg}[i_1,\dots,i_n])\f$ |
class Tanh : public UnaryElementwiseArithmetic class Tanh : public util::UnaryElementwiseArithmetic
{ {
public: public:
/// \brief Constructs a hyperbolic tangent operation. /// \brief Constructs a hyperbolic tangent operation.
......
...@@ -17,15 +17,15 @@ ...@@ -17,15 +17,15 @@
#include <memory> #include <memory>
#include "ngraph/log.hpp" #include "ngraph/log.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwise::BinaryElementwise(const std::string& node_type, op::util::BinaryElementwise::BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type, const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg1)
: RequiresTensorViewArgs(node_type, Nodes{arg0, arg1}) : RequiresTensorViewArgs(node_type, Nodes{arg0, arg1})
{ {
auto& input_0 = get_inputs().at(0); auto& input_0 = get_inputs().at(0);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary operations, i.e., operations where the same
/// scalar binary operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | ----------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg0` | \f$E_0[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E_0\f$. |
/// | `arg1` | \f$E_1[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape as `arg0`. Subclasses may impose restrictions on the element type \f$E_1\f$. |
///
/// ## Output
///
/// | Type | Description |
/// | ------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E_2[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, but subclasses must determine the element type \f$E_2\f$. |
class BinaryElementwise : public RequiresTensorViewArgs
{
protected:
/// \brief Constructs a biary elementwise operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
...@@ -14,14 +14,15 @@ ...@@ -14,14 +14,15 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_arithmetic.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const std::string& node_type, op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(
const std::shared_ptr<Node>& arg0, const std::string& node_type,
const std::shared_ptr<Node>& arg1) const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1) : BinaryElementwise(node_type, arg0->get_element_type(), arg0, arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type()) if (arg0->get_element_type() != arg1->get_element_type())
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/binary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary arithmetic operations, i.e., operations where the same
/// scalar binary arithmetic operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying arithmetic operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. The element type \f$N\f$ may be any numeric type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensors. |
class BinaryElementwiseArithmetic : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise arithmetic operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/binary_elementwise_comparison.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type, op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const string& node_type,
const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1) const shared_ptr<Node>& arg1)
: BinaryElementwise(node_type, element::boolean, arg0, arg1) : BinaryElementwise(node_type, element::boolean, arg0, arg1)
{ {
if (arg0->get_element_type() != arg1->get_element_type()) if (arg0->get_element_type() != arg1->get_element_type())
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/binary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise binary comparison operations, i.e., operations where the same
/// scalar binary comparison operation is applied to each corresponding pair of elements in two same-shaped
/// input tensors.
///
/// For example, if the underlying comparison operation (determined by the subclass) is \f$\mathit{op}(x,y)\f$, the input tensors
/// \f$[[x_0,y_0],[z_0,w_0]]\f$ and \f$[[x_1,y_1],[z_1,w_1]]\f$ will be mapped to \f$[[\mathit{op}(x_0,x_1),\mathit{op}(y_0,y_1)],[\mathit{op}(z_0,z_1),\mathit{op}(w_0,w_1)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg0}[i_1,\dots,i_n],\texttt{arg1}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensors, and the element type `bool`. |
class BinaryElementwiseComparison : public BinaryElementwise
{
public:
/// \brief Constructs a binary elementwise comparison operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
BinaryElementwiseComparison(const std::string& node_type,
const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1);
};
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <algorithm>
#include <memory>
#include <sstream>
#include "ngraph/except.hpp"
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/types/type.hpp"
using namespace ngraph;
using namespace std;
op::util::RequiresTensorViewArgs::RequiresTensorViewArgs(const std::string& node_type,
const Nodes& args)
: Op(node_type, args)
{
for (auto arg : args)
{
if (arg->get_output_size() != 1)
{
throw ngraph_error("Arguments for node type \"" + node_type +
"\" must be tensor views");
}
}
}
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/common.hpp"
#include "ngraph/ops/op.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for ops on tensors views.
class RequiresTensorViewArgs : public ngraph::op::Op
{
protected:
/// \brief Constructs an operation on tensor view arguments.
///
/// \param args The nodes producing this node's input tensors.
RequiresTensorViewArgs(const std::string& node_type, const Nodes& args);
};
}
}
}
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
#include <memory> #include <memory>
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::UnaryElementwise::UnaryElementwise(const std::string& node_type, op::util::UnaryElementwise::UnaryElementwise(const std::string& node_type,
const element::Type& result_element_type, const element::Type& result_element_type,
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: RequiresTensorViewArgs(node_type, Nodes{arg}) : RequiresTensorViewArgs(node_type, Nodes{arg})
{ {
auto& input = get_inputs().at(0); auto& input = get_inputs().at(0);
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise unary operations, i.e., operations where the same
/// scalar operation is applied to each element.
///
/// For example, if the underlying operation (determined by the subclass) is \f$\mathit{op}(x)\f$, the input tensor
/// \f$[[x,y],[z,w]]\f$ will be mapped to \f$[[\mathit{op}(x),\mathit{op}(y)],[\mathit{op}(z),\mathit{op}(w)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
/// | `arg` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. Subclasses may impose restrictions on the element type \f$E\f$. |
///
/// ## Output
///
/// | Type | Description |
/// | ----------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$E'[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape as the input tensor, but subclasses must determine the element type \f$E'\f$. |
class UnaryElementwise : public RequiresTensorViewArgs
{
protected:
/// \brief Constructs a unary elementwise tensor operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwise(const std::string& node_type,
const element::Type& result_element_type,
const std::shared_ptr<Node>& arg);
};
}
}
}
...@@ -14,12 +14,12 @@ ...@@ -14,12 +14,12 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/unary_elementwise_arithmetic.hpp"
using namespace ngraph; using namespace ngraph;
op::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type, op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg) const std::shared_ptr<Node>& arg)
: UnaryElementwise(node_type, arg->get_element_type(), arg) : UnaryElementwise(node_type, arg->get_element_type(), arg)
{ {
if (arg->get_element_type() == element::boolean) if (arg->get_element_type() == element::boolean)
......
/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/ops/util/unary_elementwise.hpp"
namespace ngraph
{
namespace op
{
namespace util
{
/// \brief Abstract base class for elementwise unary arithmetic operations, i.e., operations where the same
/// scalar arithmetic operation is applied to each element.
///
/// For example, if the underlying operation (determined by the subclass) is \f$\mathit{op}(x)\f$, the input tensor
/// \f$[[x,y],[z,w]]\f$ will be mapped to \f$[[\mathit{op}(x),\mathit{op}(y)],[\mathit{op}(z),\mathit{op}(w)]]\f$.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ----- | --------------------------------- | ------------------------------------------------------------------------ |
/// | `arg` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape. The element type \f$N\f$ may be any numeric type. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \mathit{op}(\texttt{arg}[i_1,\dots,i_n])\f$. This will always have the same shape and element type as the input tensor. |
class UnaryElementwiseArithmetic : public UnaryElementwise
{
protected:
/// \brief Constructs a unary elementwise arithmetic operation.
///
/// \param arg Node that produces the input tensor.
UnaryElementwiseArithmetic(const std::string& node_type,
const std::shared_ptr<Node>& arg);
};
}
}
}
This diff is collapsed.
...@@ -24,12 +24,12 @@ ...@@ -24,12 +24,12 @@
#include "ngraph/runtime/cpu/cpu_external_function.hpp" #include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp" #include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#define EMITTER_DECL(E) \ #define EMITTER_DECL(op_name) \
E(ngraph::runtime::cpu::CPU_ExternalFunction* external_function, \ emit<op_name>(CPU_ExternalFunction * external_function, \
codegen::CodeWriter& writer, \ codegen::CodeWriter & writer, \
const ngraph::Node* node, \ const ngraph::Node* node, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& args, \ const std::vector<TensorViewWrapper>& args, \
const std::vector<ngraph::runtime::cpu::TensorViewWrapper>& out) const std::vector<TensorViewWrapper>& out)
namespace ngraph namespace ngraph
{ {
...@@ -40,72 +40,25 @@ namespace ngraph ...@@ -40,72 +40,25 @@ namespace ngraph
class CPU_Emitter class CPU_Emitter
{ {
public: public:
static void EMITTER_DECL(EmitNop); template <typename OP>
static void EMITTER_DECL(EmitAdd); static void emit(CPU_ExternalFunction* external_function,
#ifdef NGRAPH_DISTRIBUTED codegen::CodeWriter& writer,
static void EMITTER_DECL(EmitAllReduce); const ngraph::Node* node,
#endif const std::vector<TensorViewWrapper>& args,
static void EMITTER_DECL(EmitDot); const std::vector<TensorViewWrapper>& out)
static void EMITTER_DECL(EmitMultiply); {
static void EMITTER_DECL(EmitGetOutputElement); throw std::runtime_error("Unimplemented op in CPU emitter");
static void EMITTER_DECL(EmitXLAGetTupleElement); }
static void EMITTER_DECL(EmitTuple);
static void EMITTER_DECL(EmitAbs);
static void EMITTER_DECL(EmitConcat);
static void EMITTER_DECL(EmitDivide);
static void EMITTER_DECL(EmitEqual);
static void EMITTER_DECL(EmitGreater);
static void EMITTER_DECL(EmitGreaterEq);
static void EMITTER_DECL(EmitLess);
static void EMITTER_DECL(EmitLessEq);
static void EMITTER_DECL(EmitLog);
static void EMITTER_DECL(EmitMaximum);
static void EMITTER_DECL(EmitMinimum);
static void EMITTER_DECL(EmitNegative);
static void EMITTER_DECL(EmitNotEqual);
static void EMITTER_DECL(EmitSelect);
static void EMITTER_DECL(EmitSubtract);
static void EMITTER_DECL(EmitBroadcast);
static void EMITTER_DECL(EmitMatmulBias);
static void EMITTER_DECL(EmitConvert);
static void EMITTER_DECL(EmitConstant);
static void EMITTER_DECL(EmitReshape);
static void EMITTER_DECL(EmitFunctionCall);
static void EMITTER_DECL(EmitReduce);
static void EMITTER_DECL(EmitSign);
static void EMITTER_DECL(EmitSlice);
static void EMITTER_DECL(EmitSum);
static void EMITTER_DECL(EmitExp);
static void EMITTER_DECL(EmitSin);
static void EMITTER_DECL(EmitSinh);
static void EMITTER_DECL(EmitCos);
static void EMITTER_DECL(EmitCosh);
static void EMITTER_DECL(EmitTan);
static void EMITTER_DECL(EmitTanh);
static void EMITTER_DECL(EmitAsin);
static void EMITTER_DECL(EmitAcos);
static void EMITTER_DECL(EmitAtan);
static void EMITTER_DECL(EmitPower);
static void EMITTER_DECL(EmitReplaceSlice);
static void EMITTER_DECL(EmitOneHot);
static void EMITTER_DECL(EmitFloor);
static void EMITTER_DECL(EmitCeiling);
static void EMITTER_DECL(EmitSqrt);
static void EMITTER_DECL(EmitConvolution);
static void EMITTER_DECL(EmitConvolutionBackpropFilters);
static void EMITTER_DECL(EmitConvolutionBackpropData);
static void EMITTER_DECL(EmitNot);
static void EMITTER_DECL(EmitMaxPool);
static void EMITTER_DECL(EmitReverse);
static void EMITTER_DECL(EmitReduceWindow);
static void EMITTER_DECL(EmitSelectAndScatter);
static void EMITTER_DECL(EmitAvgPool);
static void EMITTER_DECL(EmitAvgPoolBackprop);
static void EMITTER_DECL(EmitPad);
static void EMITTER_DECL(EmitBatchNorm);
static void EMITTER_DECL(EmitMaxPoolBackprop);
static void EmitMKLDNNPreamble(codegen::CodeWriter& writer); static void nop(CPU_ExternalFunction* external_function,
codegen::CodeWriter& writer,
const ngraph::Node* node,
const std::vector<TensorViewWrapper>& args,
const std::vector<TensorViewWrapper>& out)
{
}
static void emit_mkldnn_preamble(codegen::CodeWriter& writer);
private: private:
static std::string emit_vector(const TensorViewWrapper&, static std::string emit_vector(const TensorViewWrapper&,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#pragma once #pragma once
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
namespace ngraph namespace ngraph
{ {
...@@ -31,7 +31,7 @@ namespace ngraph ...@@ -31,7 +31,7 @@ namespace ngraph
/// \brief Layout Conversion /// \brief Layout Conversion
/// ///
/// Converts an input tensor to a tensor with the given layout descriptor /// Converts an input tensor to a tensor with the given layout descriptor
class ConvertLayout : public ngraph::op::RequiresTensorViewArgs class ConvertLayout : public ngraph::op::util::RequiresTensorViewArgs
{ {
public: public:
ConvertLayout( ConvertLayout(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#pragma once #pragma once
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/ops/op.hpp" #include "ngraph/ops/util/requires_tensor_view_args.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include <memory> #include <memory>
...@@ -26,7 +26,7 @@ namespace ngraph ...@@ -26,7 +26,7 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class MatmulBias : public RequiresTensorViewArgs class MatmulBias : public util::RequiresTensorViewArgs
{ {
public: public:
MatmulBias(std::shared_ptr<Node> W, MatmulBias(std::shared_ptr<Node> W,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment