Unverified Commit a6da7f1b authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Merge pull request #3278 from NervanaSystems/cyphers/s-barannikov

Cyphers/s barannikov
parents d92dcdfe ff2e7fe4
...@@ -214,7 +214,7 @@ namespace ngraph ...@@ -214,7 +214,7 @@ namespace ngraph
virtual bool is_constant() const; virtual bool is_constant() const;
virtual bool is_null() const { return false; } virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; } virtual bool is_op() const { return false; }
virtual bool is_commutative() { return false; } virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const; virtual bool is_dynamic() const;
virtual bool has_state() const { return false; } virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; } size_t get_instance_id() const { return m_instance_id; }
......
...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
} }
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Add>(arg0, arg1); return make_shared<op::Add>(arg0, arg1);
} }
...@@ -51,13 +51,12 @@ namespace ngraph ...@@ -51,13 +51,12 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
}; };
} }
std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0, std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
const std::shared_ptr<ngraph::Node>& arg1);
} }
...@@ -51,8 +51,7 @@ namespace ngraph ...@@ -51,8 +51,7 @@ namespace ngraph
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
protected: virtual bool is_commutative() const override { return true; }
virtual bool is_commutative() override { return true; }
}; };
} }
} }
...@@ -22,12 +22,15 @@ ...@@ -22,12 +22,15 @@
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"}; using namespace std;
using namespace ngraph;
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input, const string op::BatchNormTraining::type_name{"BatchNormTraining"};
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta, op::BatchNormTraining::BatchNormTraining(const Output<Node>& input,
double epsilon) const Output<Node>& gamma,
const Output<Node>& beta,
double epsilon)
: Op({gamma, beta, input}) : Op({gamma, beta, input})
, m_epsilon(epsilon) , m_epsilon(epsilon)
{ {
...@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input, ...@@ -35,17 +38,17 @@ ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
} }
// DEPRECATED // DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps, op::BatchNormTraining::BatchNormTraining(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input) const Output<Node>& input)
: Op({gamma, beta, input}) : Op({gamma, beta, input})
, m_epsilon(eps) , m_epsilon(eps)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormTraining::validate_and_infer_types() void op::BatchNormTraining::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
PartialShape result_batch_shape; PartialShape result_batch_shape;
...@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types() ...@@ -66,16 +69,15 @@ void ngraph::op::BatchNormTraining::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node> op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
ngraph::op::BatchNormTraining::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormTraining>( return std::make_shared<BatchNormTraining>(
new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon); new_args.at(2), new_args.at(0), new_args.at(1), m_epsilon);
} }
void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints, void op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) const NodeVector& deltas)
{ {
auto gamma = input(0).get_source_output(); auto gamma = input(0).get_source_output();
auto beta = input(1).get_source_output(); auto beta = input(1).get_source_output();
...@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin ...@@ -102,14 +104,14 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta); adjoints.add_delta(beta, dbeta);
} }
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"}; const string op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input, op::BatchNormInference::BatchNormInference(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
double epsilon) double epsilon)
: Op({gamma, beta, input, mean, variance}) : Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon) , m_epsilon(epsilon)
{ {
...@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input, ...@@ -117,19 +119,19 @@ ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
} }
// DEPRECATED // DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps, op::BatchNormInference::BatchNormInference(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance) const Output<Node>& variance)
: Op({gamma, beta, input, mean, variance}) : Op({gamma, beta, input, mean, variance})
, m_epsilon(eps) , m_epsilon(eps)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormInference::validate_and_infer_types() void op::BatchNormInference::validate_and_infer_types()
{ {
element::Type result_et; element::Type result_et;
PartialShape result_batch_shape; PartialShape result_batch_shape;
...@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types() ...@@ -152,23 +154,22 @@ void ngraph::op::BatchNormInference::validate_and_infer_types()
set_output_type(0, result_et, result_batch_shape); set_output_type(0, result_et, result_batch_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node> op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
ngraph::op::BatchNormInference::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<BatchNormInference>( return std::make_shared<BatchNormInference>(
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon); new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
} }
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"}; const string op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input, op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
Output<ngraph::Node> delta, const Output<Node>& delta,
double epsilon) double epsilon)
: Op({gamma, beta, input, mean, variance, delta}) : Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon) , m_epsilon(epsilon)
...@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph:: ...@@ -177,13 +178,13 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
Output<ngraph::Node> delta) const Output<Node>& delta)
: Op({gamma, beta, input, mean, variance, delta}) : Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon) , m_epsilon(epsilon)
...@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon, ...@@ -192,7 +193,7 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() void op::BatchNormTrainingBackprop::validate_and_infer_types()
{ {
PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)}; PartialShape input_and_delta_shape{get_input_partial_shape(INPUT_DATA)};
...@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types() ...@@ -239,8 +240,8 @@ void ngraph::op::BatchNormTrainingBackprop::validate_and_infer_types()
set_output_type(2, result_et, result_channel_shape); set_output_type(2, result_et, result_channel_shape);
} }
std::shared_ptr<ngraph::Node> std::shared_ptr<Node>
ngraph::op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const op::BatchNormTrainingBackprop::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2), return std::make_shared<op::BatchNormTrainingBackprop>(new_args.at(2),
......
...@@ -39,9 +39,9 @@ namespace ngraph ...@@ -39,9 +39,9 @@ namespace ngraph
/// \param gamma gamma scaling for normalized value. [C] /// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C] /// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance /// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(Output<Node> input, BatchNormTraining(const Output<Node>& input,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
...@@ -66,9 +66,9 @@ namespace ngraph ...@@ -66,9 +66,9 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis. /// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps, BatchNormTraining(double eps,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> input); const Output<Node>& input);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -101,11 +101,11 @@ namespace ngraph ...@@ -101,11 +101,11 @@ namespace ngraph
/// \param mean value for mean normalization [C] /// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C] /// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance /// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(Output<ngraph::Node> input, BatchNormInference(const Output<Node>& input,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance, const Output<Node>& variance,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
...@@ -128,11 +128,11 @@ namespace ngraph ...@@ -128,11 +128,11 @@ namespace ngraph
/// output: shall have the same shape as 'input'. /// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps, BatchNormInference(double eps,
Output<ngraph::Node> gamma, const Output<Node>& gamma,
Output<ngraph::Node> beta, const Output<Node>& beta,
Output<ngraph::Node> input, const Output<Node>& input,
Output<ngraph::Node> mean, const Output<Node>& mean,
Output<ngraph::Node> variance); const Output<Node>& variance);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -165,24 +165,23 @@ namespace ngraph ...@@ -165,24 +165,23 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default; BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input, BatchNormTrainingBackprop(const Output<Node>& input,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> mean, const Output<Node>& mean,
Output<Node> variance, const Output<Node>& variance,
Output<Node> delta, const Output<Node>& delta,
double epsilon); double epsilon);
NGRAPH_DEPRECATED_DOC NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor") NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon, BatchNormTrainingBackprop(double epsilon,
Output<Node> gamma, const Output<Node>& gamma,
Output<Node> beta, const Output<Node>& beta,
Output<Node> input, const Output<Node>& input,
const Output<Node>& mean,
Output<Node> mean, const Output<Node>& variance,
Output<Node> variance, const Output<Node>& delta);
Output<Node> delta);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1) shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Divide>(arg0, arg1); return make_shared<op::Divide>(arg0, arg1);
} }
...@@ -64,6 +64,5 @@ namespace ngraph ...@@ -64,6 +64,5 @@ namespace ngraph
}; };
} }
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0, std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
const Output<ngraph::Node> arg1);
} }
...@@ -58,7 +58,7 @@ namespace ngraph ...@@ -58,7 +58,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; } size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count) void set_reduction_axes_count(size_t reduction_axes_count)
{ {
m_reduction_axes_count = reduction_axes_count; m_reduction_axes_count = reduction_axes_count;
} }
......
...@@ -56,6 +56,8 @@ namespace ngraph ...@@ -56,6 +56,8 @@ namespace ngraph
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
}; };
} }
} }
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::Gather::type_name{"Gather"};
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,13 +26,15 @@ namespace ngraph ...@@ -26,13 +26,15 @@ namespace ngraph
class Gather : public Op class Gather : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Gather() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather /// \param axis Axis in params to gather
Gather(const std::shared_ptr<Node>& params, Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
const std::shared_ptr<Node>& indices, : Op({params, indices})
size_t axis = 0)
: Op("Gather", check_single_output_args({params, indices}))
, m_axis(axis) , m_axis(axis)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -46,6 +48,7 @@ namespace ngraph ...@@ -46,6 +48,7 @@ namespace ngraph
} }
size_t get_axis() const { return m_axis; } size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::GatherND::type_name{"GatherND"};
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,10 +26,14 @@ namespace ngraph ...@@ -26,10 +26,14 @@ namespace ngraph
class GatherND : public Op class GatherND : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GatherND() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices) GatherND(const Output<Node>& params, const Output<Node>& indices)
: Op("GatherND", check_single_output_args({params, indices})) : Op({params, indices})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0, const string op::Greater::type_name{"Greater"};
const shared_ptr<Node>& arg1,
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Greater", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class Greater : public util::BinaryElementwiseComparison class Greater : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation. /// \brief Constructs a greater-than operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Greater(const std::shared_ptr<Node>& arg0, Greater(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const string op::GreaterEq::type_name{"GreaterEq"};
const shared_ptr<Node>& arg1,
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class GreaterEq : public util::BinaryElementwiseComparison class GreaterEq : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation. /// \brief Constructs a greater-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
GreaterEq(const std::shared_ptr<Node>& arg0, GreaterEq(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Less::Less(const shared_ptr<Node>& arg0, const string op::Less::type_name{"Less"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Less::Less(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Less", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class Less : public util::BinaryElementwiseComparison class Less : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation. /// \brief Constructs a less-than operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Less(const std::shared_ptr<Node>& arg0, Less(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::LessEq::LessEq(const shared_ptr<Node>& arg0, const string op::LessEq::type_name{"LessEq"};
const shared_ptr<Node>& arg1,
op::LessEq::LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("LessEq", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class LessEq : public util::BinaryElementwiseComparison class LessEq : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
LessEq(const std::shared_ptr<Node>& arg0, LessEq(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Log::Log(const shared_ptr<Node>& arg) const string op::Log::type_name{"Log"};
: UnaryElementwiseArithmetic("Log", arg)
op::Log::Log(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Log : public util::UnaryElementwiseArithmetic class Log : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a natural log operation.
Log() = default;
/// \brief Constructs a natural log operation. /// \brief Constructs a natural log operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Log(const std::shared_ptr<Node>& arg); Log(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,12 +20,14 @@ ...@@ -20,12 +20,14 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::LRN::LRN(const std::shared_ptr<Node>& arg, double alpha, double beta, double bias, size_t nsize) const string op::LRN::type_name{"LRN"};
: UnaryElementwiseArithmetic("LRN", arg)
op::LRN::LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size)
: UnaryElementwiseArithmetic(arg)
, m_alpha(alpha) , m_alpha(alpha)
, m_beta(beta) , m_beta(beta)
, m_bias(bias) , m_bias(bias)
, m_size(nsize) , m_size(size)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -38,23 +38,28 @@ namespace ngraph ...@@ -38,23 +38,28 @@ namespace ngraph
class LRN : public util::UnaryElementwiseArithmetic class LRN : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a LRN operation.
LRN() = default;
/// \brief Constructs a LRN operation. /// \brief Constructs a LRN operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
LRN(const std::shared_ptr<Node>& arg, LRN(const Output<Node>& arg, double alpha, double beta, double bias, size_t size);
double alpha,
double beta,
double bias,
size_t size);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; } double get_alpha() const { return m_alpha; }
void set_alpha(double alpha) { m_alpha = alpha; }
double get_beta() const { return m_beta; } double get_beta() const { return m_beta; }
void set_beta(double beta) { m_beta = beta; }
double get_bias() const { return m_bias; } double get_bias() const { return m_bias; }
void set_bias(double bias) { m_bias = bias; }
size_t get_nsize() const { return m_size; } size_t get_nsize() const { return m_size; }
void set_nsize(size_t size) { m_size = size; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -22,10 +22,6 @@ using namespace ngraph; ...@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Max::type_name{"Max"}; const string op::Max::type_name{"Max"};
op::Max::Max()
{
}
op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes) op::Max::Max(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a "max" reduction operation. /// \brief Constructs a "max" reduction operation.
Max(); Max() = default;
/// \brief Constructs a max-reduction operation. /// \brief Constructs a max-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
......
...@@ -25,14 +25,16 @@ ...@@ -25,14 +25,16 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const string op::MaxPool::type_name{"MaxPool"};
op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above, const Shape& padding_above,
const PadType& pad_type, const PadType& pad_type,
bool ceil_mode) bool ceil_mode)
: Op("MaxPool", check_single_output_args({arg})) : Op({arg})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -43,7 +45,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg, ...@@ -54,7 +56,7 @@ op::MaxPool::MaxPool(const shared_ptr<Node>& arg,
{ {
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types() ...@@ -121,14 +123,14 @@ void op::MaxPool::validate_and_infer_types()
m_ceil_mode)); m_ceil_mode));
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, op::MaxPool::MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides) const Strides& window_movement_strides)
: MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape()) : MaxPool(arg, window_shape, window_movement_strides, Shape(), Shape())
{ {
} }
op::MaxPool::MaxPool(const shared_ptr<Node>& arg, const Shape& window_shape) op::MaxPool::MaxPool(const Output<Node>& arg, const Shape& window_shape)
: MaxPool(arg, window_shape, Strides(), Shape(), Shape()) : MaxPool(arg, window_shape, Strides(), Shape(), Shape())
{ {
} }
...@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con ...@@ -145,13 +147,15 @@ shared_ptr<Node> op::MaxPool::copy_with_new_args(const NodeVector& new_args) con
m_ceil_mode); m_ceil_mode);
} }
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, const string op::MaxPoolBackprop::type_name{"MaxPoolBackprop"};
const shared_ptr<Node>& delta,
op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta})) : Op({arg_forward, delta})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
...@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, ...@@ -160,14 +164,14 @@ op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::MaxPoolBackprop::MaxPoolBackprop(const shared_ptr<Node>& arg_forward, op::MaxPoolBackprop::MaxPoolBackprop(const Output<Node>& arg_forward,
const shared_ptr<Node>& delta, const Output<Node>& delta,
const shared_ptr<Node>& result_forward, const Output<Node>& result_forward,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above) const Shape& padding_above)
: Op("MaxPoolBackprop", check_single_output_args({arg_forward, delta, result_forward})) : Op({arg_forward, delta, result_forward})
, m_window_shape(window_shape) , m_window_shape(window_shape)
, m_window_movement_strides(window_movement_strides) , m_window_movement_strides(window_movement_strides)
, m_padding_below(padding_below) , m_padding_below(padding_below)
......
...@@ -28,6 +28,12 @@ namespace ngraph ...@@ -28,6 +28,12 @@ namespace ngraph
class MaxPool : public Op class MaxPool : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched max pooling operation.
MaxPool() = default;
/// \brief Constructs a batched max pooling operation. /// \brief Constructs a batched max pooling operation.
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
...@@ -37,7 +43,7 @@ namespace ngraph ...@@ -37,7 +43,7 @@ namespace ngraph
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes /// \param pad_type The pad type for automatically computing padding sizes
/// \param ceil_mode Whether to use ceiling while computing output shape. /// \param ceil_mode Whether to use ceiling while computing output shape.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -53,7 +59,7 @@ namespace ngraph ...@@ -53,7 +59,7 @@ namespace ngraph
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
/// \param pad_type The pad type for automatically computing padding sizes /// \param pad_type The pad type for automatically computing padding sizes
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -67,7 +73,7 @@ namespace ngraph ...@@ -67,7 +73,7 @@ namespace ngraph
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
/// \param padding_below The below-padding shape. /// \param padding_below The below-padding shape.
/// \param padding_above The above-padding shape. /// \param padding_above The above-padding shape.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -80,7 +86,7 @@ namespace ngraph ...@@ -80,7 +86,7 @@ namespace ngraph
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
/// \param window_movement_strides The window movement strides. /// \param window_movement_strides The window movement strides.
MaxPool(const std::shared_ptr<Node>& arg, MaxPool(const Output<Node>& arg,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides); const Strides& window_movement_strides);
...@@ -88,23 +94,32 @@ namespace ngraph ...@@ -88,23 +94,32 @@ namespace ngraph
/// ///
/// \param arg The node producing the input data batch tensor. /// \param arg The node producing the input data batch tensor.
/// \param window_shape The window shape. /// \param window_shape The window shape.
MaxPool(const std::shared_ptr<Node>& arg, const Shape& window_shape); MaxPool(const Output<Node>& arg, const Shape& window_shape);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
/// \return The window shape. /// \return The window shape.
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
/// \return The window movement strides. /// \return The window movement strides.
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
/// \return The below-padding shape. /// \return The below-padding shape.
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
/// \return The above-padding shape. /// \return The above-padding shape.
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
void set_adding_above(const Shape& padding_above) { m_padding_above = padding_above; }
/// \return The pad type for pooling. /// \return The pad type for pooling.
const PadType& get_pad_type() const { return m_pad_type; } const PadType& get_pad_type() const { return m_pad_type; }
void set_pad_type(const PadType& pad_type) { m_pad_type = pad_type; }
/// \return The ceiling mode being used for output shape computations /// \return The ceiling mode being used for output shape computations
bool get_ceil_mode() const { return m_ceil_mode; } bool get_ceil_mode() const { return m_ceil_mode; }
void set_ceil_mode(bool ceil_mode) { m_ceil_mode = ceil_mode; }
/// \return The default value for MaxPool. /// \return The default value for MaxPool.
virtual std::shared_ptr<Node> get_default_value() const override virtual std::shared_ptr<Node> get_default_value() const override
{ {
...@@ -126,16 +141,21 @@ namespace ngraph ...@@ -126,16 +141,21 @@ namespace ngraph
class MaxPoolBackprop : public Op class MaxPoolBackprop : public Op
{ {
public: public:
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, NGRAPH_API
const std::shared_ptr<Node>& delta, static const std::string type_name;
const std::string& description() const override { return type_name; }
MaxPoolBackprop() = default;
MaxPoolBackprop(const Output<Node>& arg_forward,
const Output<Node>& delta,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
const Shape& padding_above); const Shape& padding_above);
MaxPoolBackprop(const std::shared_ptr<Node>& arg_forward, MaxPoolBackprop(const Output<Node>& arg_forward,
const std::shared_ptr<Node>& delta, const Output<Node>& delta,
const std::shared_ptr<Node>& result_forward, const Output<Node>& result_forward,
const Shape& window_shape, const Shape& window_shape,
const Strides& window_movement_strides, const Strides& window_movement_strides,
const Shape& padding_below, const Shape& padding_below,
...@@ -147,9 +167,16 @@ namespace ngraph ...@@ -147,9 +167,16 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
const Shape& get_window_shape() const { return m_window_shape; } const Shape& get_window_shape() const { return m_window_shape; }
void set_window_shape(const Shape& window_shape) { m_window_shape = window_shape; }
const Strides& get_window_movement_strides() const { return m_window_movement_strides; } const Strides& get_window_movement_strides() const { return m_window_movement_strides; }
void set_window_movement_strides(const Strides& window_movement_strides)
{
m_window_movement_strides = window_movement_strides;
}
const Shape& get_padding_below() const { return m_padding_below; } const Shape& get_padding_below() const { return m_padding_below; }
void set_padding_below(const Shape& padding_below) { m_padding_below = padding_below; }
const Shape& get_padding_above() const { return m_padding_above; } const Shape& get_padding_above() const { return m_padding_above; }
void set_padding_above(const Shape& padding_above) { m_padding_above = padding_above; }
protected: protected:
Shape m_window_shape; Shape m_window_shape;
Strides m_window_movement_strides; Strides m_window_movement_strides;
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Maximum::Maximum(const shared_ptr<Node>& arg0, const string op::Maximum::type_name{"Maximum"};
const shared_ptr<Node>& arg1,
op::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Maximum", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,19 +26,24 @@ namespace ngraph ...@@ -26,19 +26,24 @@ namespace ngraph
class Maximum : public util::BinaryElementwiseArithmetic class Maximum : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation. /// \brief Constructs a maximum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Maximum(const std::shared_ptr<Node>& arg0, Maximum(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() override { return true; } virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -22,10 +22,6 @@ using namespace ngraph; ...@@ -22,10 +22,6 @@ using namespace ngraph;
const string op::Min::type_name{"Min"}; const string op::Min::type_name{"Min"};
op::Min::Min()
{
}
op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes) op::Min::Min(const Output<Node>& arg, const AxisSet& reduction_axes)
: ArithmeticReduction(arg, reduction_axes) : ArithmeticReduction(arg, reduction_axes)
{ {
......
...@@ -30,7 +30,7 @@ namespace ngraph ...@@ -30,7 +30,7 @@ namespace ngraph
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
/// \brief Constructs a "min" reduction operation. /// \brief Constructs a "min" reduction operation.
Min(); Min() = default;
/// \brief Constructs a min-reduction operation. /// \brief Constructs a min-reduction operation.
/// ///
/// \param arg The tensor to be reduced. /// \param arg The tensor to be reduced.
......
...@@ -25,10 +25,12 @@ ...@@ -25,10 +25,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Minimum::Minimum(const shared_ptr<Node>& arg0, const string op::Minimum::type_name{"Minimum"};
const shared_ptr<Node>& arg1,
op::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Minimum", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,18 +26,24 @@ namespace ngraph ...@@ -26,18 +26,24 @@ namespace ngraph
class Minimum : public util::BinaryElementwiseArithmetic class Minimum : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation. /// \brief Constructs a minimum operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Minimum(const std::shared_ptr<Node>& arg0, Minimum(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Multiply::Multiply(const shared_ptr<Node>& arg0, const string op::Multiply::type_name{"Multiply"};
const shared_ptr<Node>& arg1,
op::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Multiply", arg0, arg1, autob) : BinaryElementwiseArithmetic(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -49,7 +51,7 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
} }
shared_ptr<Node> ngraph::operator*(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1) shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Multiply>(arg0, arg1); return make_shared<op::Multiply>(arg0, arg1);
} }
...@@ -26,25 +26,29 @@ namespace ngraph ...@@ -26,25 +26,29 @@ namespace ngraph
class Multiply : public util::BinaryElementwiseArithmetic class Multiply : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation. /// \brief Constructs a multiplication operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Multiply(const std::shared_ptr<Node>& arg0, Multiply(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
virtual bool is_commutative() override { return true; }
}; };
}; };
std::shared_ptr<ngraph::Node> operator*(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
const std::shared_ptr<ngraph::Node> arg1);
} }
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Negative::Negative(const shared_ptr<Node>& arg) const string op::Negative::type_name{"Negative"};
: UnaryElementwiseArithmetic("Negative", arg)
op::Negative::Negative(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -40,7 +42,7 @@ void op::Negative::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(x, -delta); adjoints.add_delta(x, -delta);
} }
shared_ptr<Node> ngraph::operator-(const shared_ptr<Node> arg0) shared_ptr<Node> ngraph::operator-(const Output<Node>& arg0)
{ {
return make_shared<op::Negative>(arg0); return make_shared<op::Negative>(arg0);
} }
...@@ -26,17 +26,23 @@ namespace ngraph ...@@ -26,17 +26,23 @@ namespace ngraph
class Negative : public util::UnaryElementwiseArithmetic class Negative : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a negative operation.
Negative() = default;
/// \brief Constructs a negative operation. /// \brief Constructs a negative operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Negative(const std::shared_ptr<Node>& arg); Negative(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
} }
std::shared_ptr<ngraph::Node> operator-(const std::shared_ptr<ngraph::Node> arg0); std::shared_ptr<Node> operator-(const Output<Node>& arg0);
} }
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace ngraph; using namespace ngraph;
using namespace std; using namespace std;
op::Not::Not(const shared_ptr<Node>& arg) const string op::Not::type_name{"Not"};
: Op("Not", check_single_output_args({arg}))
op::Not::Not(const Output<Node>& arg)
: Op({arg})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Not : public Op class Not : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical negation operation.
Not() = default;
/// \brief Constructs a logical negation operation. /// \brief Constructs a logical negation operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Not(const std::shared_ptr<Node>& arg); Not(const Output<Node>& arg);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::NotEqual::NotEqual(const shared_ptr<Node>& arg0, const string op::NotEqual::type_name{"NotEqual"};
const shared_ptr<Node>& arg1,
op::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("NotEqual", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,17 +26,24 @@ namespace ngraph ...@@ -26,17 +26,24 @@ namespace ngraph
class NotEqual : public util::BinaryElementwiseComparison class NotEqual : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation. /// \brief Constructs a not-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
NotEqual(const std::shared_ptr<Node>& arg0, NotEqual(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
}; };
} }
} }
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::OneHot::OneHot(const shared_ptr<Node>& arg, const PartialShape& shape, size_t one_hot_axis) const string op::OneHot::type_name{"OneHot"};
: Op("OneHot", check_single_output_args({arg}))
op::OneHot::OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis)
: Op({arg})
, m_shape(shape) , m_shape(shape)
, m_one_hot_axis(one_hot_axis) , m_one_hot_axis(one_hot_axis)
{ {
......
...@@ -45,14 +45,17 @@ namespace ngraph ...@@ -45,14 +45,17 @@ namespace ngraph
class OneHot : public Op class OneHot : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a one-hot operation.
OneHot() = default;
/// \brief Constructs a one-hot operation. /// \brief Constructs a one-hot operation.
/// ///
/// \param arg Node that produces the input tensor to be one-hot encoded. /// \param arg Node that produces the input tensor to be one-hot encoded.
/// \param shape The shape of the output tensor, including the new one-hot axis. /// \param shape The shape of the output tensor, including the new one-hot axis.
/// \param one_hot_axis The index within the output shape of the new one-hot axis. /// \param one_hot_axis The index within the output shape of the new one-hot axis.
OneHot(const std::shared_ptr<Node>& arg, OneHot(const Output<Node>& arg, const PartialShape& shape, size_t one_hot_axis);
const PartialShape& shape,
size_t one_hot_axis);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -60,6 +63,7 @@ namespace ngraph ...@@ -60,6 +63,7 @@ namespace ngraph
/// \return The index of the one-hot axis. /// \return The index of the one-hot axis.
size_t get_one_hot_axis() const { return m_one_hot_axis; } size_t get_one_hot_axis() const { return m_one_hot_axis; }
void set_one_hot_axis(size_t one_hot_axis) { m_one_hot_axis = one_hot_axis; }
protected: protected:
PartialShape m_shape; PartialShape m_shape;
size_t m_one_hot_axis; size_t m_one_hot_axis;
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Or::Or(const shared_ptr<Node>& arg0, const string op::Or::type_name{"Or"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Or::Or(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical("Or", arg0, arg1, autob) : BinaryElementwiseLogical(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -29,6 +29,9 @@ namespace ngraph ...@@ -29,6 +29,9 @@ namespace ngraph
class Or : public util::BinaryElementwiseLogical class Or : public util::BinaryElementwiseLogical
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-or operation. /// \brief Constructs a logical-or operation.
/// ///
/// \param arg0 Node that produces the first input tensor.<br> /// \param arg0 Node that produces the first input tensor.<br>
...@@ -39,15 +42,14 @@ namespace ngraph ...@@ -39,15 +42,14 @@ namespace ngraph
/// ///
/// Output `[d0, ...]` /// Output `[d0, ...]`
/// ///
Or(const std::shared_ptr<Node>& arg0, Or(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
protected: virtual bool is_commutative() const override { return true; }
virtual bool is_commutative() override { return true; }
}; };
} }
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment