Commit 6abeca2b authored by Scott Cyphers

Merge branch 's-barannikov/new_op_form/l-o_ops' into cyphers/s-barannikov

parents dfa5d4d1 3396681c
@@ -19,10 +19,10 @@
 using namespace std;
 using namespace ngraph;

-op::Less::Less(const shared_ptr<Node>& arg0,
-               const shared_ptr<Node>& arg1,
-               const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("Less", arg0, arg1, autob)
+const string op::Less::type_name{"Less"};
+
+op::Less::Less(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,13 +26,18 @@ namespace ngraph
         class Less : public util::BinaryElementwiseComparison
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a less-than operation.
+            Less() = default;
             /// \brief Constructs a less-than operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Less(const std::shared_ptr<Node>& arg0,
-                 const std::shared_ptr<Node>& arg1,
+            Less(const Output<Node>& arg0,
+                 const Output<Node>& arg1,
                  const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...
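For context, a minimal sketch of constructing ops under the new Output<Node> form introduced by this commit; the element type, shapes, and names below are assumptions for illustration, not part of the change:

    #include <memory>
    #include "ngraph/ngraph.hpp" // umbrella header; assumed available in this tree

    using namespace ngraph;

    std::shared_ptr<Node> less_example()
    {
        // Example parameters; element type and shape are invented for this sketch.
        auto a = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
        auto b = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});

        // A shared_ptr<Node> (or a shared_ptr to a derived op) converts implicitly
        // to an Output<Node> referring to output index 0, so existing call sites
        // that passed node pointers keep compiling unchanged.
        auto less = std::make_shared<op::Less>(a, b);

        // An Output<Node> can also be named and passed explicitly.
        Output<Node> a_out{a};
        auto less2 = std::make_shared<op::Less>(a_out, b);
        return less2;
    }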
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;

-op::LessEq::LessEq(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+const string op::LessEq::type_name{"LessEq"};
+
+op::LessEq::LessEq(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("LessEq", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,13 +26,18 @@ namespace ngraph
         class LessEq : public util::BinaryElementwiseComparison
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a less-than-or-equal operation.
+            LessEq() = default;
             /// \brief Constructs a less-than-or-equal operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            LessEq(const std::shared_ptr<Node>& arg0,
-                   const std::shared_ptr<Node>& arg1,
+            LessEq(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...
@@ -20,8 +20,10 @@
 using namespace std;
 using namespace ngraph;

-op::Log::Log(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Log", arg)
+const string op::Log::type_name{"Log"};
+
+op::Log::Log(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,10 +26,15 @@ namespace ngraph
         class Log : public util::UnaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a natural log operation.
+            Log() = default;
             /// \brief Constructs a natural log operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Log(const std::shared_ptr<Node>& arg);
+            Log(const Output<Node>& arg);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -20,12 +20,14 @@
 using namespace std;
 using namespace ngraph;

-op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize)
-    : UnaryElementwiseArithmetic("LRN", arg)
+const string op::LRN::type_name{"LRN"};
+
+op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
+    : UnaryElementwiseArithmetic(arg)
     , m_alpha(alpha)
     , m_beta(beta)
     , m_bias(bias)
-    , m_size(nsize)
+    , m_size(size)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -38,23 +38,29 @@ namespace ngraph
         class LRN : public util::UnaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a LRN operation.
+            LRN() = default;
             /// \brief Constructs a LRN operation.
             ///
             /// \param arg Node that produces the input tensor.
-            LRN(const std::shared_ptr<Node>& arg,
-                double alpha,
-                double beta,
-                double bias,
-                size_t size);
+            LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

             void validate_and_infer_types() override;

             double get_alpha() const { return m_alpha; }
+            void set_alpha(double alpha) { m_alpha = alpha; }
             double get_beta() const { return m_beta; }
+            void set_beta(double beta) { m_beta = beta; }
             double get_bias() const { return m_bias; }
+            void set_bias(double bias) { m_bias = bias; }
             size_t get_nsize() const { return m_size; }
+            void set_nsize(size_t size) { m_size = size; }
         protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
...
@@ -22,10 +22,6 @@ using namespace ngraph;
 const string op::Max::type_name{"Max"};

-op::Max::Max()
-{
-}
-
 op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
     : ArithmeticReduction(arg, reduction_axes)
 {
...
@@ -30,7 +30,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs a "max" reduction operation.
-            Max();
+            Max() = default;
             /// \brief Constructs a max-reduction operation.
             ///
             /// \param arg The tensor to be reduced.
...
@@ -25,14 +25,16 @@
 using namespace std;
 using namespace ngraph;

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+const string op::MaxPool::type_name{"MaxPool"};
+
+op::MaxPool::MaxPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides,
                      const Shape& padding_below,
                      const Shape& padding_above,
                      const PadType& pad_type,
                      bool ceil_mode)
-    : Op("MaxPool", check_single_output_args({arg}))
+    : Op({arg})
     , m_window_shape(window_shape)
     , m_window_movement_strides(window_movement_strides)
     , m_padding_below(padding_below)
@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const Output<Node>& arg,
     constructor_validate_and_infer_types();
 }

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides,
                      const Shape& padding_below,
@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const Output<Node>& arg,
 {
 }

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides,
                      const Shape& padding_below,
@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
                                              m_ceil_mode));
 }

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
+op::MaxPool::MaxPool(const Output<Node>& arg,
                      const Shape& window_shape,
                      const Strides& window_movement_strides)
     : MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
 {
 }

-op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape)
+op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
     : MaxPool(arg, window_shape, Strides(), Shape(), Shape())
 {
 }
@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) const
                                 m_ceil_mode);
 }

-op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
-                                     const shared_ptr<Node>& delta,
+const string op::MaxPoolBackprop::type_name{"MaxPoolBackprop"};
+
+op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+                                     const Output<Node>& delta,
                                      const Shape& window_shape,
                                      const Strides& window_movement_strides,
                                      const Shape& padding_below,
                                      const Shape& padding_above)
-    : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta}))
+    : Op({arg_forward, delta})
     , m_window_shape(window_shape)
     , m_window_movement_strides(window_movement_strides)
     , m_padding_below(padding_below)
@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
     constructor_validate_and_infer_types();
 }

-op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
-                                     const shared_ptr<Node>& delta,
-                                     const shared_ptr<Node>& result_forward,
+op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
+                                     const Output<Node>& delta,
+                                     const Output<Node>& result_forward,
                                      const Shape& window_shape,
                                      const Strides& window_movement_strides,
                                      const Shape& padding_below,
                                      const Shape& padding_above)
-    : Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward}))
+    : Op({arg_forward, delta, result_forward})
     , m_window_shape(window_shape)
     , m_window_movement_strides(window_movement_strides)
     , m_padding_below(padding_below)
...
@@ -28,6 +28,13 @@ namespace ngraph
         class MaxPool : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+
+            /// \brief Constructs a batched max pooling operation.
+            MaxPool() = default;
+
             /// \brief Constructs a batched max pooling operation.
             ///
             /// \param arg The node producing the input data batch tensor.
@@ -37,7 +44,7 @@ namespace ngraph
             /// \param padding_above The above-padding shape.
             /// \param pad_type The pad type for automatically computing padding sizes
             /// \param ceil_mode Whether to use ceiling while computing output shape.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -53,7 +60,7 @@ namespace ngraph
             /// \param padding_below The below-padding shape.
             /// \param padding_above The above-padding shape.
             /// \param pad_type The pad type for automatically computing padding sizes
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -67,7 +74,7 @@ namespace ngraph
             /// \param window_movement_strides The window movement strides.
             /// \param padding_below The below-padding shape.
             /// \param padding_above The above-padding shape.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
                     const Shape& padding_below,
@@ -80,7 +87,7 @@ namespace ngraph
             /// \param arg The node producing the input data batch tensor.
             /// \param window_shape The window shape.
             /// \param window_movement_strides The window movement strides.
-            MaxPool(const std::shared_ptr<Node>& arg,
+            MaxPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides);
@@ -88,23 +95,32 @@ namespace ngraph
             ///
             /// \param arg The node producing the input data batch tensor.
             /// \param window_shape The window shape.
-            MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape);
+            MaxPool(const Output<Node>& arg, const Shape& window_shape);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

             /// \return The window shape.
             const Shape& get_window_shape() const { return m_window_shape; }
+            void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
             /// \return The window movement strides.
             const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+            void set_window_movement_strides(const Strides& window_movement_strides)
+            {
+                m_window_movement_strides = window_movement_strides;
+            }
             /// \return The below-padding shape.
             const Shape& get_padding_below() const { return m_padding_below; }
+            void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
             /// \return The above-padding shape.
             const Shape& get_padding_above() const { return m_padding_above; }
+            void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
             /// \return The pad type for pooling.
             const PadType& get_pad_type() const { return m_pad_type; }
+            void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
             /// \return The ceiling mode being used for output shape computations
             bool get_ceil_mode() const { return m_ceil_mode; }
+            void set_ceil_mode(bool ceil_mode) { m_ceil_mode = ceil_mode; }
             /// \return The default value for MaxPool.
             virtual std::shared_ptr<Node> get_default_value() const override
             {
@@ -126,16 +142,22 @@ namespace ngraph
         class MaxPoolBackprop : public Op
         {
         public:
-            MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
-                            const std::shared_ptr<Node>& delta,
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+
+            MaxPoolBackprop() = default;
+
+            MaxPoolBackprop(const Output<Node>& arg_forward,
+                            const Output<Node>& delta,
                             const Shape& window_shape,
                             const Strides& window_movement_strides,
                             const Shape& padding_below,
                             const Shape& padding_above);

-            MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward,
-                            const std::shared_ptr<Node>& delta,
-                            const std::shared_ptr<Node>& result_forward,
+            MaxPoolBackprop(const Output<Node>& arg_forward,
+                            const Output<Node>& delta,
+                            const Output<Node>& result_forward,
                             const Shape& window_shape,
                             const Strides& window_movement_strides,
                             const Shape& padding_below,
@@ -147,9 +169,17 @@ namespace ngraph
             void validate_and_infer_types() override;

             const Shape& get_window_shape() const { return m_window_shape; }
+            void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
             const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
+            void set_window_movement_strides(const Strides& window_movement_strides)
+            {
+                m_window_movement_strides = window_movement_strides;
+            }
             const Shape& get_padding_below() const { return m_padding_below; }
+            void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
             const Shape& get_padding_above() const { return m_padding_above; }
+            void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }

         protected:
             Shape m_window_shape;
             Strides m_window_movement_strides;
...
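As a usage note, the shorter delegating MaxPool constructors shown above can be called as in the sketch below; the input shape and pooling parameters are invented for illustration and are not part of this commit:

    #include <memory>
    #include "ngraph/ngraph.hpp" // umbrella header; assumed available in this tree

    using namespace ngraph;

    std::shared_ptr<Node> max_pool_example()
    {
        // NCHW input: batch 1, 3 channels, 8x8 spatial dims (example values only).
        auto data = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 8, 8});

        // A 2x2 window moved with stride 2 over the spatial dims; the three-argument
        // delegating constructor fills in empty padding shapes.
        return std::make_shared<op::MaxPool>(data, Shape{2, 2}, Strides{2, 2});
    }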
@@ -25,10 +25,12 @@
 using namespace std;
 using namespace ngraph;

-op::Maximum::Maximum(const shared_ptr<Node>& arg0,
-                     const shared_ptr<Node>& arg1,
+const string op::Maximum::type_name{"Maximum"};
+
+op::Maximum::Maximum(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Maximum", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,13 +26,18 @@ namespace ngraph
         class Maximum : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a maximum operation.
+            Maximum() = default;
             /// \brief Constructs a maximum operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Maximum(const std::shared_ptr<Node>& arg0,
-                    const std::shared_ptr<Node>& arg1,
+            Maximum(const Output<Node>& arg0,
+                    const Output<Node>& arg1,
                     const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...
@@ -22,10 +22,6 @@ using namespace ngraph;
 const string op::Min::type_name{"Min"};

-op::Min::Min()
-{
-}
-
 op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
     : ArithmeticReduction(arg, reduction_axes)
 {
...
@@ -30,7 +30,7 @@ namespace ngraph
             static const std::string type_name;
             const std::string& description() const override { return type_name; }
             /// \brief Constructs a "min" reduction operation.
-            Min();
+            Min() = default;
             /// \brief Constructs a min-reduction operation.
             ///
             /// \param arg The tensor to be reduced.
...
@@ -25,10 +25,12 @@
 using namespace std;
 using namespace ngraph;

-op::Minimum::Minimum(const shared_ptr<Node>& arg0,
-                     const shared_ptr<Node>& arg1,
+const string op::Minimum::type_name{"Minimum"};
+
+op::Minimum::Minimum(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Minimum", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,13 +26,18 @@ namespace ngraph
         class Minimum : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a minimum operation.
+            Minimum() = default;
             /// \brief Constructs a minimum operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Minimum(const std::shared_ptr<Node>& arg0,
-                    const std::shared_ptr<Node>& arg1,
+            Minimum(const Output<Node>& arg0,
+                    const Output<Node>& arg1,
                     const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;

-op::Multiply::Multiply(const shared_ptr<Node>& arg0,
-                       const shared_ptr<Node>& arg1,
+const string op::Multiply::type_name{"Multiply"};
+
+op::Multiply::Multiply(const Output<Node>& arg0,
+                       const Output<Node>& arg1,
                        const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Multiply", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
     adjoints.add_delta(y, x * delta);
 }

-shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
+shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
 {
     return make_shared<op::Multiply>(arg0, arg1);
 }
@@ -26,13 +26,18 @@ namespace ngraph
         class Multiply : public util::BinaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a multiplication operation.
+            Multiply() = default;
             /// \brief Constructs a multiplication operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            Multiply(const std::shared_ptr<Node>& arg0,
-                     const std::shared_ptr<Node>& arg1,
+            Multiply(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
@@ -45,6 +50,5 @@ namespace ngraph
         };
     };

-    std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0,
-                                            const std::shared_ptr<ngraph::Node> arg1);
+    std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
 }
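Since ngraph::operator* now takes Output<Node> (and operator- is changed the same way in the Negative diff below), expression-style graph building keeps working for both shared_ptr<Node> and explicit Output<Node> operands. A hedged sketch; the function and variable names are assumptions for illustration:

    #include <memory>
    #include "ngraph/ngraph.hpp" // umbrella header; assumed available in this tree

    using namespace ngraph;

    std::shared_ptr<Node> negated_product(const Output<Node>& x, const Output<Node>& y)
    {
        // ngraph::operator* builds an op::Multiply node and returns a shared_ptr<Node>.
        auto product = x * y;

        // ngraph::operator- builds an op::Negative node; the shared_ptr<Node> result
        // of the multiplication converts implicitly back to Output<Node>.
        return -product;
    }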
@@ -19,8 +19,10 @@
 using namespace std;
 using namespace ngraph;

-op::Negative::Negative(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Negative", arg)
+const string op::Negative::type_name{"Negative"};
+
+op::Negative::Negative(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
 {
     constructor_validate_and_infer_types();
 }
@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
     adjoints.add_delta(x, -delta);
 }

-shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0)
+shared_ptr<Node> ngraph::operator-(const Output<Node>& arg0)
 {
     return make_shared<op::Negative>(arg0);
 }
@@ -26,17 +26,23 @@ namespace ngraph
         class Negative : public util::UnaryElementwiseArithmetic
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a negative operation.
+            Negative() = default;
             /// \brief Constructs a negative operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Negative(const std::shared_ptr<Node>& arg);
+            Negative(const Output<Node>& arg);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;

+        protected:
             virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                            const NodeVector& deltas) override;
         };
     }

-    std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0);
+    std::shared_ptr<Node> operator-(const Output<Node>& arg0);
 }
@@ -20,8 +20,10 @@
 using namespace ngraph;
 using namespace std;

-op::Not::Not(const shared_ptr<Node>& arg)
-    : Op("Not", check_single_output_args({arg}))
+const string op::Not::type_name{"Not"};
+
+op::Not::Not(const Output<Node>& arg)
+    : Op({arg})
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,10 +26,15 @@ namespace ngraph
         class Not : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a logical negation operation.
+            Not() = default;
             /// \brief Constructs a logical negation operation.
             ///
             /// \param arg Node that produces the input tensor.
-            Not(const std::shared_ptr<Node>& arg);
+            Not(const Output<Node>& arg);

             void validate_and_infer_types() override;
...
@@ -19,10 +19,12 @@
 using namespace std;
 using namespace ngraph;

-op::NotEqual::NotEqual(const shared_ptr<Node>& arg0,
-                       const shared_ptr<Node>& arg1,
+const string op::NotEqual::type_name{"NotEqual"};
+
+op::NotEqual::NotEqual(const Output<Node>& arg0,
+                       const Output<Node>& arg1,
                        const AutoBroadcastSpec& autob)
-    : BinaryElementwiseComparison("NotEqual", arg0, arg1, autob)
+    : BinaryElementwiseComparison(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -26,13 +26,18 @@ namespace ngraph
         class NotEqual : public util::BinaryElementwiseComparison
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a not-equal operation.
+            NotEqual() = default;
             /// \brief Constructs a not-equal operation.
             ///
             /// \param arg0 Node that produces the first input tensor.
             /// \param arg1 Node that produces the second input tensor.
             /// \param autob Auto broadcast specification
-            NotEqual(const std::shared_ptr<Node>& arg0,
-                     const std::shared_ptr<Node>& arg1,
+            NotEqual(const Output<Node>& arg0,
+                     const Output<Node>& arg1,
                      const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...
@@ -20,8 +20,10 @@
 using namespace std;
 using namespace ngraph;

-op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
-    : Op("OneHot", check_single_output_args({arg}))
+const string op::OneHot::type_name{"OneHot"};
+
+op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
+    : Op({arg})
     , m_shape(shape)
     , m_one_hot_axis(one_hot_axis)
 {
...
@@ -45,14 +45,17 @@ namespace ngraph
         class OneHot : public Op
         {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
+            /// \brief Constructs a one-hot operation.
+            OneHot() = default;
             /// \brief Constructs a one-hot operation.
             ///
             /// \param arg Node that produces the input tensor to be one-hot encoded.
             /// \param shape The shape of the output tensor, including the new one-hot axis.
             /// \param one_hot_axis The index within the output shape of the new one-hot axis.
-            OneHot(const std::shared_ptr<Node>& arg,
-                   const PartialShape& shape,
-                   size_t one_hot_axis);
+            OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);

             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
@@ -60,6 +63,8 @@ namespace ngraph
             /// \return The index of the one-hot axis.
             size_t get_one_hot_axis() const { return m_one_hot_axis; }
+            void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }

         protected:
             PartialShape m_shape;
             size_t m_one_hot_axis;
...
@@ -19,10 +19,10 @@
 using namespace std;
 using namespace ngraph;

-op::Or::Or(const shared_ptr<Node>& arg0,
-           const shared_ptr<Node>& arg1,
-           const AutoBroadcastSpec& autob)
-    : BinaryElementwiseLogical("Or", arg0, arg1, autob)
+const string op::Or::type_name{"Or"};
+
+op::Or::Or(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
+    : BinaryElementwiseLogical(arg0, arg1, autob)
 {
     constructor_validate_and_infer_types();
 }
...
@@ -29,6 +29,9 @@ namespace ngraph
         class Or : public util::BinaryElementwiseLogical
        {
         public:
+            NGRAPH_API
+            static const std::string type_name;
+            const std::string& description() const override { return type_name; }
             /// \brief Constructs a logical-or operation.
             ///
             /// \param arg0 Node that produces the first input tensor.<br>
@@ -39,8 +42,8 @@ namespace ngraph
             ///
             /// Output `[d0, ...]`
             ///
-            Or(const std::shared_ptr<Node>& arg0,
-               const std::shared_ptr<Node>& arg1,
+            Or(const Output<Node>& arg0,
+               const Output<Node>& arg1,
                const AutoBroadcastSpec& autob = AutoBroadcastSpec());

             virtual std::shared_ptr<Node>
...