Unverified Commit 0552fd84 authored by Scott Cyphers, committed by GitHub

Convert to new op form (#3112)

parent 5e7aacf1
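For orientation before the per-file hunks: the commit applies one mechanical pattern across ops. Constructors take `Output<Node>` arguments instead of `shared_ptr<Node>`, the type string and `check_single_output_args` wrapper passed to the base `Op`/`Node` constructor go away, each op gains a static `type_name` (usually with `NGRAPH_API`) plus a `description()` override, and the hand-written empty default constructors in the .cpp files become `= default` in the headers. Below is a minimal, hedged sketch of a hypothetical op written in the converted style; `MyOp` is not part of this commit, it only mirrors the pattern visible in the diffs that follow.

```cpp
#include <ngraph/ngraph.hpp>

// Hypothetical example only; it follows the converted form of Abs/Cos/etc. below.
class MyOp : public ngraph::op::util::UnaryElementwiseArithmetic
{
public:
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// Replaces the empty default constructors that used to live in the .cpp files.
    MyOp() = default;
    /// New-style constructor: an Output<Node> instead of a shared_ptr<Node>.
    MyOp(const ngraph::Output<ngraph::Node>& arg)
        : UnaryElementwiseArithmetic(arg)
    {
        constructor_validate_and_infer_types();
    }
    std::shared_ptr<ngraph::Node>
        copy_with_new_args(const ngraph::NodeVector& new_args) const override
    {
        return std::make_shared<MyOp>(new_args.at(0));
    }
};

const std::string MyOp::type_name{"MyOp"};
```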
@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
    return args;
}

+OutputVector ngraph::as_output_vector(const NodeVector& args)
+{
+    OutputVector output_vector;
+    for (auto& arg : check_single_output_args(args))
+    {
+        output_vector.push_back(arg);
+    }
+    return output_vector;
+}
+
std::tuple<element::Type, PartialShape>
    Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
{
...
@@ -73,6 +73,8 @@ namespace ngraph
                  size_t i);

    const NodeVector& check_single_output_args(const NodeVector& args);
+   OutputVector as_output_vector(const NodeVector& args);
+
    /// Alias useful for cloning
    using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
...
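A hedged usage sketch of the new helper (Parameter shapes and the Concat call are illustrative, not taken from the commit): `as_output_vector` wraps each single-output node as an `Output<Node>`, which is what the new-style constructors accept.

```cpp
#include <ngraph/ngraph.hpp>

using namespace ngraph;

std::shared_ptr<Node> concat_via_outputs()
{
    std::shared_ptr<Node> a = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
    std::shared_ptr<Node> b = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});

    // check_single_output_args (called inside as_output_vector) rejects multi-output nodes.
    NodeVector nodes{a, b};
    OutputVector outputs = as_output_vector(nodes);

    // OutputVector feeds the new-style constructors, e.g. Concat along axis 0 -> Shape{4, 2}.
    return std::make_shared<op::Concat>(outputs, 0);
}
```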
@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::Abs::type_name{"Abs"};

-op::Abs::Abs()
-{
-}
-
op::Abs::Abs(const Output<Node>& arg)
    : UnaryElementwiseArithmetic(arg)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an absolute value operation.
-   Abs();
+   Abs() = default;
    /// \brief Constructs an absolute value operation.
    ///
...
@@ -34,10 +34,6 @@ using namespace ngraph;
const string op::Acos::type_name{"Acos"};

-op::Acos::Acos()
-{
-}
-
op::Acos::Acos(const Output<Node>& arg)
    : UnaryElementwiseArithmetic(arg)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an arccos operation.
-   Acos();
+   Acos() = default;
    /// \brief Constructs an arccos operation.
    ///
    /// \param arg Output that produces the input tensor.<br>
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Add::type_name{"Add"};

-op::Add::Add()
-{
-}
-
op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
    : BinaryElementwiseArithmetic(arg0, arg1, autob)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an unitialized addition operation
-   Add();
+   Add() = default;
    /// \brief Constructs an addition operation.
    ///
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::All::type_name{"All"};

-op::All::All()
-{
-}
-
op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
    : LogicalReduction(arg, reduction_axes)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an "all" reduction operation.
-   All();
+   All() = default;
    /// \brief Constructs an "all" reduction operation.
    ///
    /// \param arg The tensor to be reduced.
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::And::type_name{"And"};

-op::And::And()
-{
-}
-
op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
    : BinaryElementwiseLogical(arg0, arg1, autob)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a logical-and operation.
-   And();
+   And() = default;
    /// \brief Constructs a logical-and operation.
    ///
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Any::type_name{"Any"};

-op::Any::Any()
-{
-}
-
op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
    : LogicalReduction(arg, reduction_axes)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an "any" reduction operation.
-   Any();
+   Any() = default;
    /// \brief Constructs an "any" reduction operation.
    ///
    /// \param arg The tensor to be reduced.
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMax::type_name{"ArgMax"};

-op::ArgMax::ArgMax()
-{
-}
-
op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
    : op::util::IndexReduction(arg, axis, index_element_type)
{
...
@@ -32,7 +32,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a ArgMax operation.
-   ArgMax();
+   ArgMax() = default;
    /// \brief Constructs a ArgMax operation.
    ///
    /// \param arg The input tensor
...
@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMin::type_name{"ArgMin"};

-op::ArgMin::ArgMin()
-{
-}
-
op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
    : op::util::IndexReduction(arg, axis, index_element_type)
{
...
@@ -32,7 +32,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a ArgMin operation.
-   ArgMin();
+   ArgMin() = default;
    /// \brief Constructs a ArgMin operation.
    ///
...
@@ -33,10 +33,6 @@ using namespace ngraph;
const string op::Asin::type_name{"Asin"};

-op::Asin::Asin()
-{
-}
-
op::Asin::Asin(const Output<Node>& arg)
    : UnaryElementwiseArithmetic(arg)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an arcsin operation.
-   Asin();
+   Asin() = default;
    /// \brief Constructs an arcsin operation.
    ///
    /// \param arg Output that produces the input tensor.<br>
...
@@ -32,10 +32,6 @@ using namespace ngraph;
const string op::Atan::type_name{"Atan"};

-op::Atan::Atan()
-{
-}
-
op::Atan::Atan(const Output<Node>& arg)
    : UnaryElementwiseArithmetic(arg)
{
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs an arctan operation.
-   Atan();
+   Atan() = default;
    /// \brief Constructs an arctan operation.
    ///
...
@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::AvgPool::type_name{"AvgPool"};

-op::AvgPool::AvgPool()
-{
-}
-
op::AvgPool::AvgPool(const Output<Node>& arg,
                     const Shape& window_shape,
                     const Strides& window_movement_strides,
@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");

-op::AvgPoolBackprop::AvgPoolBackprop()
-{
-}
-
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
                                     const shared_ptr<Node>& delta,
                                     const Shape& window_shape,
...
@@ -33,7 +33,7 @@ namespace ngraph
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
    /// \brief Constructs a batched average pooling operation.
-   AvgPool();
+   AvgPool() = default;
    /// \brief Constructs a batched average pooling operation.
    ///
@@ -175,7 +175,7 @@ namespace ngraph
public:
    static const std::string type_name;
    const std::string& description() const override { return type_name; }
-   AvgPoolBackprop();
+   AvgPoolBackprop() = default;
    AvgPoolBackprop(const Shape& forward_arg_shape,
                    const std::shared_ptr<Node>& delta,
                    const Shape& window_shape,
...
@@ -22,11 +22,13 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"

-ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> input,
-                                                 std::shared_ptr<ngraph::Node> gamma,
-                                                 std::shared_ptr<ngraph::Node> beta,
+const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
+
+ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
+                                                 Output<ngraph::Node> gamma,
+                                                 Output<ngraph::Node> beta,
                                                  double epsilon)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+    : Op({gamma, beta, input})
    , m_epsilon(epsilon)
{
    constructor_validate_and_infer_types();
@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
// DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
-                                                std::shared_ptr<ngraph::Node> gamma,
-                                                std::shared_ptr<ngraph::Node> beta,
-                                                std::shared_ptr<ngraph::Node> input)
-    : Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
+                                                Output<ngraph::Node> gamma,
+                                                Output<ngraph::Node> beta,
+                                                Output<ngraph::Node> input)
+    : Op({gamma, beta, input})
    , m_epsilon(eps)
{
    constructor_validate_and_infer_types();
@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
    adjoints.add_delta(beta, dbeta);
}

-ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                                                   std::shared_ptr<ngraph::Node> gamma,
-                                                   std::shared_ptr<ngraph::Node> beta,
-                                                   std::shared_ptr<ngraph::Node> mean,
-                                                   std::shared_ptr<ngraph::Node> variance,
+const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
+
+ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
+                                                   Output<ngraph::Node> gamma,
+                                                   Output<ngraph::Node> beta,
+                                                   Output<ngraph::Node> mean,
+                                                   Output<ngraph::Node> variance,
                                                    double epsilon)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+    : Op({gamma, beta, input, mean, variance})
    , m_epsilon(epsilon)
{
    constructor_validate_and_infer_types();
@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
// DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps,
-                                                  std::shared_ptr<ngraph::Node> gamma,
-                                                  std::shared_ptr<ngraph::Node> beta,
-                                                  std::shared_ptr<ngraph::Node> input,
-                                                  std::shared_ptr<ngraph::Node> mean,
-                                                  std::shared_ptr<ngraph::Node> variance)
-    : Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
+                                                  Output<ngraph::Node> gamma,
+                                                  Output<ngraph::Node> beta,
+                                                  Output<ngraph::Node> input,
+                                                  Output<ngraph::Node> mean,
+                                                  Output<ngraph::Node> variance)
+    : Op({gamma, beta, input, mean, variance})
    , m_epsilon(eps)
{
    constructor_validate_and_infer_types();
@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
        new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}

-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta,
-    double epsilon)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
+
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta,
+                                                                 double epsilon)
+    : Op({gamma, beta, input, mean, variance, delta})
    , m_epsilon(epsilon)
{
@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
    constructor_validate_and_infer_types();
}

-ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
-    double epsilon,
-    std::shared_ptr<ngraph::Node> gamma,
-    std::shared_ptr<ngraph::Node> beta,
-    std::shared_ptr<ngraph::Node> input,
-    std::shared_ptr<ngraph::Node> mean,
-    std::shared_ptr<ngraph::Node> variance,
-    std::shared_ptr<ngraph::Node> delta)
-    : Op("BatchNormTrainingBackprop",
-         check_single_output_args({gamma, beta, input, mean, variance, delta}))
+ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
+                                                                 Output<ngraph::Node> gamma,
+                                                                 Output<ngraph::Node> beta,
+                                                                 Output<ngraph::Node> input,
+                                                                 Output<ngraph::Node> mean,
+                                                                 Output<ngraph::Node> variance,
+                                                                 Output<ngraph::Node> delta)
+    : Op({gamma, beta, input, mean, variance, delta})
    , m_epsilon(epsilon)
{
...
@@ -31,13 +31,17 @@ namespace ngraph
    class BatchNormTraining : public Op
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       BatchNormTraining() = default;
        /// \param input Must have rank >= 2, [., C, ...]
        /// \param gamma gamma scaling for normalized value. [C]
        /// \param beta bias added to the scaled normalized value [C]
        /// \param epsilon Avoids divsion by 0 if input has 0 variance
-       BatchNormTraining(std::shared_ptr<Node> input,
-                         std::shared_ptr<Node> gamma,
-                         std::shared_ptr<Node> beta,
+       BatchNormTraining(Output<Node> input,
+                         Output<Node> gamma,
+                         Output<Node> beta,
                          double epsilon);

        NGRAPH_DEPRECATED_DOC
@@ -62,13 +66,14 @@ namespace ngraph
        /// output[2]: shall have rank 1, with the same span as input's channel axis.
        NGRAPH_DEPRECATED("Use another constructor")
        BatchNormTraining(double eps,
-                         std::shared_ptr<Node> gamma,
-                         std::shared_ptr<Node> beta,
-                         std::shared_ptr<Node> input);
+                         Output<Node> gamma,
+                         Output<Node> beta,
+                         Output<Node> input);

        void validate_and_infer_types() override;

        double get_eps_value() const { return m_epsilon; }
+       void set_eps_value(double epsilon) { m_epsilon = epsilon; }
        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
@@ -87,17 +92,20 @@ namespace ngraph
    class BatchNormInference : public Op
    {
    public:
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       BatchNormInference() = default;
        /// \param input [., C, ...]
        /// \param gamma gamma scaling for normalized value. [C]
        /// \param beta bias added to the scaled normalized value [C]
        /// \param mean value for mean normalization [C]
        /// \param variance value for variance normalization [C]
        /// \param epsilon Avoids divsion by 0 if input has 0 variance
-       BatchNormInference(std::shared_ptr<ngraph::Node> input,
-                          std::shared_ptr<ngraph::Node> gamma,
-                          std::shared_ptr<ngraph::Node> beta,
-                          std::shared_ptr<ngraph::Node> mean,
-                          std::shared_ptr<ngraph::Node> variance,
+       BatchNormInference(Output<ngraph::Node> input,
+                          Output<ngraph::Node> gamma,
+                          Output<ngraph::Node> beta,
+                          Output<ngraph::Node> mean,
+                          Output<ngraph::Node> variance,
                           double epsilon);

        NGRAPH_DEPRECATED_DOC
@@ -120,15 +128,16 @@ namespace ngraph
        /// output: shall have the same shape as 'input'.
        NGRAPH_DEPRECATED("Use another constructor")
        BatchNormInference(double eps,
-                          std::shared_ptr<ngraph::Node> gamma,
-                          std::shared_ptr<ngraph::Node> beta,
-                          std::shared_ptr<ngraph::Node> input,
-                          std::shared_ptr<ngraph::Node> mean,
-                          std::shared_ptr<ngraph::Node> variance);
+                          Output<ngraph::Node> gamma,
+                          Output<ngraph::Node> beta,
+                          Output<ngraph::Node> input,
+                          Output<ngraph::Node> mean,
+                          Output<ngraph::Node> variance);

        void validate_and_infer_types() override;

        double get_eps_value() const { return m_epsilon; }
+       void set_eps_value(double epsilon) { m_epsilon = epsilon; }
        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
@@ -152,28 +161,33 @@ namespace ngraph
    class BatchNormTrainingBackprop : public Op
    {
    public:
-       BatchNormTrainingBackprop(std::shared_ptr<Node> input,
-                                 std::shared_ptr<Node> gamma,
-                                 std::shared_ptr<Node> beta,
-                                 std::shared_ptr<Node> mean,
-                                 std::shared_ptr<Node> variance,
-                                 std::shared_ptr<Node> delta,
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       BatchNormTrainingBackprop() = default;
+       BatchNormTrainingBackprop(Output<Node> input,
+                                 Output<Node> gamma,
+                                 Output<Node> beta,
+                                 Output<Node> mean,
+                                 Output<Node> variance,
+                                 Output<Node> delta,
                                  double epsilon);

        NGRAPH_DEPRECATED_DOC
        NGRAPH_DEPRECATED("Use another constructor")
        BatchNormTrainingBackprop(double epsilon,
-                                 std::shared_ptr<Node> gamma,
-                                 std::shared_ptr<Node> beta,
-                                 std::shared_ptr<Node> input,
-                                 std::shared_ptr<Node> mean,
-                                 std::shared_ptr<Node> variance,
-                                 std::shared_ptr<Node> delta);
+                                 Output<Node> gamma,
+                                 Output<Node> beta,
+                                 Output<Node> input,
+                                 Output<Node> mean,
+                                 Output<Node> variance,
+                                 Output<Node> delta);

        void validate_and_infer_types() override;

        double get_eps_value() const { return m_epsilon; }
+       void set_eps_value(double epsilon) { m_epsilon = epsilon; }
        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
...
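Call sites should be largely unaffected, on the assumption that a `shared_ptr<Node>` converts implicitly to `Output<Node>` (the unchanged `copy_with_new_args` bodies rely on exactly that). A hedged construction sketch with made-up shapes:

```cpp
#include <ngraph/ngraph.hpp>

using namespace ngraph;

std::shared_ptr<Node> batch_norm_inference_example()
{
    std::shared_ptr<Node> input    = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 4, 4});
    std::shared_ptr<Node> gamma    = std::make_shared<op::Parameter>(element::f32, Shape{3});
    std::shared_ptr<Node> beta     = std::make_shared<op::Parameter>(element::f32, Shape{3});
    std::shared_ptr<Node> mean     = std::make_shared<op::Parameter>(element::f32, Shape{3});
    std::shared_ptr<Node> variance = std::make_shared<op::Parameter>(element::f32, Shape{3});

    // Non-deprecated argument order: input, gamma, beta, mean, variance, epsilon.
    auto bn = std::make_shared<op::BatchNormInference>(input, gamma, beta, mean, variance, 1e-5);
    bn->set_eps_value(1e-3); // setter added in this commit
    return bn;
}
```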
@@ -20,21 +20,20 @@
using namespace std;
using namespace ngraph;

-op::Broadcast::Broadcast(const std::string& name,
-                         const NodeVector& args,
+const string op::Broadcast::type_name{"Broadcast"};
+
+op::Broadcast::Broadcast(const OutputVector& args,
                         const Shape& shape,
                         const AxisSet& broadcast_axes)
-    : Op(name, check_single_output_args(args))
+    : Op(args)
    , m_shape(shape)
    , m_broadcast_axes(broadcast_axes)
{
    constructor_validate_and_infer_types();
}

-op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
-                         const Shape& shape,
-                         const AxisSet& broadcast_axes)
-    : Broadcast("Broadcast", {arg}, shape, broadcast_axes)
+op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
+    : Broadcast(OutputVector{arg}, shape, broadcast_axes)
{
}
@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
    adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
}

-op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
-                                 const std::shared_ptr<Node>& like_arg,
+const string op::BroadcastLike::type_name{"BroadcastLike"};
+
+op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
+                                 const Output<Node>& like_arg,
                                 const AxisSet& initial_broadcast_axes)
-    : Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
+    : Broadcast({arg, like_arg}, {}, {})
    , m_initial_broadcast_axes(initial_broadcast_axes)
{
    constructor_validate_and_infer_types();
...
@@ -27,15 +27,18 @@ namespace ngraph
    class Broadcast : public Op
    {
    public:
-       /// \brief Constructs a conversion operation.
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a broadcast operation.
+       Broadcast() = default;
+       /// \brief Constructs a broadcast operation.
        ///
        /// \param arg Node that produces the input tensor to be broadcast.
        /// \param shape The shape of the output tensor.
        /// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
        ///        remaining axes in shape must be the same as the shape of arg.
-       Broadcast(const std::shared_ptr<Node>& arg,
-                 const Shape& shape,
-                 const AxisSet& broadcast_axes);
+       Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);

        void validate_and_infer_types() override;
@@ -44,12 +47,14 @@ namespace ngraph
        /// \return A set containing the indices of the broadcast axes (0-based).
        const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
+       void set_broadcast_axes(const AxisSet& broadcast_axes)
+       {
+           m_broadcast_axes = broadcast_axes;
+       }
        const Shape& get_broadcast_shape() const { return m_shape; }
+       void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
    protected:
-       Broadcast(const std::string& node_type,
-                 const NodeVector& args,
-                 const Shape& shape,
-                 const AxisSet& broadcast_axes);
+       Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);

        virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                       const NodeVector& deltas) override;
@@ -63,6 +68,11 @@ namespace ngraph
    class BroadcastLike : public Broadcast
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Broadcast arg to the same shape as like_arg.
+       BroadcastLike() = default;
        /// \brief Broadcast arg to the same shape as like_arg.
        ///
        /// Once the shape of like_arg is known, this op will be replaced with an equivalent
@@ -72,8 +82,8 @@ namespace ngraph
        /// \param like_arg Provides the shape for the result.
        /// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
        ///        arg must be scalar and all axes are broadcast.
-       BroadcastLike(const std::shared_ptr<Node>& arg,
-                     const std::shared_ptr<Node>& like_arg,
+       BroadcastLike(const Output<Node>& arg,
+                     const Output<Node>& like_arg,
                      const AxisSet& initial_broadcast_axes);

        virtual std::shared_ptr<Node>
@@ -81,6 +91,11 @@ namespace ngraph
        void infer_shape() override;

        const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
+       void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
+       {
+           m_initial_broadcast_axes = initial_broadcast_axes;
+       }
    protected:
        AxisSet m_initial_broadcast_axes;
    };
...
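A small hedged sketch of building a broadcast against the new signature (shapes and axes are illustrative): `broadcast_axes` names the result axes that are new, and the remaining axes must match the argument's shape.

```cpp
#include <ngraph/ngraph.hpp>

using namespace ngraph;

std::shared_ptr<Node> broadcast_example()
{
    // {3} broadcast to {2, 3}: axis 0 of the result is the broadcast axis.
    std::shared_ptr<Node> arg = std::make_shared<op::Parameter>(element::f32, Shape{3});
    return std::make_shared<op::Broadcast>(arg, Shape{2, 3}, AxisSet{0});
}
```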
@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;

-op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
-    : Op("BroadcastDistributed", check_single_output_args({arg}))
+const string op::BroadcastDistributed::type_name{"BroadcastDistributed"};
+
+op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id)
+    : Op({arg})
    , m_root_id(root_id)
{
    constructor_validate_and_infer_types();
@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
{
    return m_root_id;
}
+
+void op::BroadcastDistributed::set_root_id(int root_id)
+{
+    m_root_id = root_id;
+}
@@ -27,16 +27,21 @@ namespace ngraph
    class BroadcastDistributed : public Op
    {
    public:
-       BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       BroadcastDistributed() = default;
+       BroadcastDistributed(const Output<Node>& arg, int root_id = 0);

        void validate_and_infer_types() override;

        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;

        int get_root_id() const;
+       void set_root_id(int root_id);

    private:
-       const int m_root_id;
+       int m_root_id;
    };
}
}
@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;

-op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Ceiling", arg)
+const string op::Ceiling::type_name{"Ceiling"};
+
+op::Ceiling::Ceiling(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,10 +26,15 @@ namespace ngraph
    class Ceiling : public util::UnaryElementwiseArithmetic
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a ceiling operation.
+       Ceiling() = default;
        /// \brief Constructs a ceiling operation.
        ///
        /// \param arg Node that produces the input tensor.
-       Ceiling(const std::shared_ptr<Node>& arg);
+       Ceiling(const Output<Node>& arg);

        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -22,13 +22,20 @@
using namespace std;
using namespace ngraph;

-op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
-    : Op("Concat", check_single_output_args(args))
+const string op::Concat::type_name{"Concat"};
+
+op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
+    : Op(args)
    , m_concatenation_axis(concatenation_axis)
{
    constructor_validate_and_infer_types();
}

+op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
+    : Concat(as_output_vector(args), concatenation_axis)
+{
+}
+
void op::Concat::validate_and_infer_types()
{
    NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
...
@@ -28,6 +28,17 @@ namespace ngraph
    class Concat : public Op
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a concatenation operation.
+       Concat() = default;
+       /// \brief Constructs a concatenation operation.
+       ///
+       /// \param args The outputs producing the input tensors.
+       /// \param concatenation_axis The axis along which to concatenate the input tensors.
+       Concat(const OutputVector& args, size_t concatenation_axis);
+
        /// \brief Constructs a concatenation operation.
        ///
        /// \param args The nodes producing the input tensors.
@@ -41,10 +52,15 @@ namespace ngraph
        /// \return The concatenation axis.
        size_t get_concatenation_axis() const { return m_concatenation_axis; }
+       void set_concatenation_axis(size_t concatenation_axis)
+       {
+           m_concatenation_axis = concatenation_axis;
+       }
    protected:
        virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                       const NodeVector& deltas) override;

-       const size_t m_concatenation_axis;
+       size_t m_concatenation_axis;
    };
}
}
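Both constructors now coexist; the `NodeVector` overload simply funnels through `as_output_vector`. A hedged sketch of the two equivalent call forms (shapes are illustrative):

```cpp
#include <ngraph/ngraph.hpp>

using namespace ngraph;

std::shared_ptr<Node> concat_both_ways()
{
    std::shared_ptr<Node> a = std::make_shared<op::Parameter>(element::f32, Shape{2, 1});
    std::shared_ptr<Node> b = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});

    // New-style: OutputVector, concatenating along axis 1 -> Shape{2, 4}.
    auto c0 = std::make_shared<op::Concat>(OutputVector{a, b}, 1);

    // Legacy-style: NodeVector, delegated through as_output_vector internally.
    auto c1 = std::make_shared<op::Concat>(NodeVector{a, b}, 1);
    (void)c1; // builds the same graph value as c0

    return c0;
}
```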
@@ -45,6 +45,8 @@ string to_cpp_string(T value)
    return rc;
}

+const string op::Constant::type_name{"Constant"};
+
op::Constant::~Constant()
{
}
...
@@ -34,6 +34,9 @@ namespace ngraph
    class Constant : public Node
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
        /// \brief Constructs a tensor constant.
        ///
        /// \param type The element type of the tensor constant.
@@ -78,7 +81,7 @@ namespace ngraph
        /// \param shape The shape of the tensor constant.
        /// \param values A list of string values to use as the constant data.
        Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
-           : Node("Constant", {})
+           : Node({})
            , m_element_type(type)
            , m_shape(shape)
            , m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
@@ -135,7 +138,7 @@ namespace ngraph
        /// \param shape The shape of the tensor constant.
        /// \param data A void* to constant data.
        Constant(const element::Type& type, const Shape& shape, const void* data)
-           : Node("Constant", {})
+           : Node({})
            , m_element_type(type)
            , m_shape(shape)
            , m_data(nullptr)
...
@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;

-op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
-    : Op("Convert", check_single_output_args({arg}))
+const string op::Convert::type_name{"Convert"};
+
+op::Convert::Convert(const Output<Node>& arg, const element::Type& element_type)
+    : Op({arg})
    , m_element_type(element_type)
{
    constructor_validate_and_infer_types();
...
@@ -26,11 +26,16 @@ namespace ngraph
    class Convert : public Op
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a conversion operation.
+       Convert() = default;
        /// \brief Constructs a conversion operation.
        ///
        /// \param arg Node that produces the input tensor.
        /// \param element_type Element type for the output tensor.
-       Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
+       Convert(const Output<Node>& arg, const ngraph::element::Type& element_type);

        void validate_and_infer_types() override;
@@ -38,8 +43,13 @@ namespace ngraph
            copy_with_new_args(const NodeVector& new_args) const override;

        const element::Type& get_convert_element_type() const { return m_element_type; }
+       void set_convert_element_type(const element::Type& element_type)
+       {
+           m_element_type = element_type;
+       }
    protected:
-       const ngraph::element::Type m_element_type;
+       ngraph::element::Type m_element_type;

        virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                       const NodeVector& deltas) override;
    };
...
@@ -27,15 +27,17 @@
using namespace std;
using namespace ngraph;

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+const string op::Convolution::type_name{"Convolution"};
+
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                             const Strides& window_movement_strides,
                             const Strides& window_dilation_strides,
                             const CoordinateDiff& padding_below,
                             const CoordinateDiff& padding_above,
                             const Strides& data_dilation_strides,
                             const PadType& pad_type)
-    : Op("Convolution", check_single_output_args({data_batch, filters}))
+    : Op({data_batch, filters})
    , m_window_movement_strides(window_movement_strides)
    , m_window_dilation_strides(window_dilation_strides)
    , m_padding_below(padding_below)
@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
    set_output_type(0, result_et, result_shape);
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                             const Strides& window_movement_strides,
                             const Strides& window_dilation_strides,
                             const CoordinateDiff& padding_below,
@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                             const Strides& window_movement_strides,
                             const Strides& window_dilation_strides)
    : Convolution(data_batch,
@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
-                             const shared_ptr<Node>& filters,
+op::Convolution::Convolution(const Output<Node>& data_batch,
+                             const Output<Node>& filters,
                             const Strides& window_movement_strides)
    : Convolution(data_batch,
                  filters,
@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}

-op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_ptr<Node>& filters)
+op::Convolution::Convolution(const Output<Node>& data_batch, const Output<Node>& filters)
    : Convolution(data_batch, filters, Strides(), Strides(), CoordinateDiff(), CoordinateDiff())
{
}
@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
                                                    m_data_dilation_strides));
}

+const string op::ConvolutionBackpropData::type_name{"ConvolutionBackpropData"};
+
op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_shape,
-                                                     const shared_ptr<Node>& filters,
-                                                     const shared_ptr<Node>& output_delta,
+                                                     const Output<Node>& filters,
+                                                     const Output<Node>& output_delta,
                                                     const Strides& window_movement_strides_forward,
                                                     const Strides& window_dilation_strides_forward,
                                                     const CoordinateDiff& padding_below_forward,
                                                     const CoordinateDiff& padding_above_forward,
                                                     const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropData", check_single_output_args({filters, output_delta}))
+    : Op({filters, output_delta})
    , m_data_batch_shape(data_batch_shape)
    , m_window_movement_strides_forward(window_movement_strides_forward)
    , m_window_dilation_strides_forward(window_dilation_strides_forward)
@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
                                     m_data_dilation_strides_forward[i]);
    }

-   auto swap_NC = [](const shared_ptr<Node> n) {
-       AxisVector ax_order = ngraph::get_default_order(n->get_shape());
+   auto swap_NC = [](const Output<Node>& n) {
+       AxisVector ax_order = ngraph::get_default_order(n.get_shape());
        ax_order[0] = 1;
        ax_order[1] = 0;

-       auto new_shape = n->get_shape();
-       new_shape[0] = n->get_shape()[1];
-       new_shape[1] = n->get_shape()[0];
+       auto new_shape = n.get_shape();
+       new_shape[0] = n.get_shape()[1];
+       new_shape[1] = n.get_shape()[0];

        return make_shared<op::Reshape>(n, ax_order, new_shape);
    };
@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
    return backward_delta_out_pad_above;
}

+const string op::ConvolutionBackpropFilters::type_name{"ConvolutionBackpropFilters"};
+
op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
-    const shared_ptr<Node>& data_batch,
+    const Output<Node>& data_batch,
    const Shape& filters_shape,
-    const shared_ptr<Node>& output_delta,
+    const Output<Node>& output_delta,
    const Strides& window_movement_strides_forward,
    const Strides& window_dilation_strides_forward,
    const CoordinateDiff& padding_below_forward,
    const CoordinateDiff& padding_above_forward,
    const Strides& data_dilation_strides_forward)
-    : Op("ConvolutionBackpropFilters", check_single_output_args({data_batch, output_delta}))
+    : Op({data_batch, output_delta})
    , m_filters_shape(filters_shape)
    , m_window_movement_strides_forward(window_movement_strides_forward)
    , m_window_dilation_strides_forward(window_dilation_strides_forward)
...
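One mechanical consequence of the `Output<Node>` parameters shows up in the `swap_NC` helper above: an `Output<Node>` is a small value handle (node plus output index), so members are reached with `.` rather than `->`, and `get_shape()` now means the shape of that specific output. A hedged side-by-side of the idiom:

```cpp
#include <ngraph/ngraph.hpp>

using namespace ngraph;

// Old style: a node pointer; get_shape() implicitly refers to output 0.
Shape shape_via_node(const std::shared_ptr<Node>& n)
{
    return n->get_shape();
}

// New style: an output handle; get_shape() is the shape of exactly that output.
Shape shape_via_output(const Output<Node>& n)
{
    return n.get_shape();
}
```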
This diff is collapsed.
@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;

-op::Cos::Cos(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cos", arg)
+const string op::Cos::type_name{"Cos"};
+
+op::Cos::Cos(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,10 +26,15 @@ namespace ngraph
    class Cos : public util::UnaryElementwiseArithmetic
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a cosine operation.
+       Cos() = default;
        /// \brief Constructs a cosine operation.
        ///
        /// \param arg Node that produces the input tensor.
-       Cos(const std::shared_ptr<Node>& arg);
+       Cos(const Output<Node>& arg);

        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;

-op::Cosh::Cosh(const shared_ptr<Node>& arg)
-    : UnaryElementwiseArithmetic("Cosh", arg)
+const string op::Cosh::type_name{"Cosh"};
+
+op::Cosh::Cosh(const Output<Node>& arg)
+    : UnaryElementwiseArithmetic(arg)
{
    constructor_validate_and_infer_types();
}
...
@@ -26,10 +26,15 @@ namespace ngraph
    class Cosh : public util::UnaryElementwiseArithmetic
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a hyperbolic cosine operation.
+       Cosh() = default;
        /// \brief Constructs a hyperbolic cosine operation.
        ///
        /// \param arg Node that produces the input tensor.
-       Cosh(const std::shared_ptr<Node>& arg);
+       Cosh(const Output<Node>& arg);

        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -20,13 +20,15 @@
using namespace std;
using namespace ngraph;

-op::Dequantize::Dequantize(const shared_ptr<Node>& input,
-                           const shared_ptr<Node>& scale,
-                           const shared_ptr<Node>& zero_point,
+const string op::Dequantize::type_name{"Dequantize"};
+
+op::Dequantize::Dequantize(const Output<Node>& input,
+                           const Output<Node>& scale,
+                           const Output<Node>& zero_point,
                            const element::Type& type,
                            const AxisSet& axes)
-    : Op("Dequantize", check_single_output_args({input, scale, zero_point}))
+    : Op({input, scale, zero_point})
    , m_type(type)
    , m_axes(axes)
{
...
@@ -30,31 +30,40 @@ namespace ngraph
    class Dequantize : public ngraph::op::Op
    {
    public:
+       NGRAPH_API
+       static const std::string type_name;
+       const std::string& description() const override { return type_name; }
+       /// \brief Constructs a Dequantize operation
+       Dequantize() = default;
        /// \brief Constructs a Dequantize operation
        /// \param input quantized input
        /// \param scale scale used for mapping
        /// \param zero_point zero point used for mapping
        /// \param type output element type
        /// \param axes axis positions on which `scale` and `zero_point` are specified
-       Dequantize(const std::shared_ptr<Node>& input,
-                  const std::shared_ptr<Node>& scale,
-                  const std::shared_ptr<Node>& zero_point,
-                  const ngraph::element::Type& type,
-                  const ngraph::AxisSet& axes);
+       Dequantize(const Output<Node>& input,
+                  const Output<Node>& scale,
+                  const Output<Node>& zero_point,
+                  const element::Type& type,
+                  const AxisSet& axes);

        void validate_and_infer_types() override;

        virtual std::shared_ptr<Node>
            copy_with_new_args(const NodeVector& new_args) const override;

-       const ngraph::AxisSet& get_axes() const { return m_axes; }
+       const AxisSet& get_axes() const { return m_axes; }
+       void set_axes(const AxisSet& axes) { m_axes = axes; }
+       const element::Type& get_type() const { return m_type; }
+       void set_type(const element::Type& type) { m_type = type; }
    protected:
        virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                       const NodeVector& deltas) override;

    private:
-       ngraph::element::Type m_type;
-       ngraph::AxisSet m_axes;
+       element::Type m_type;
+       AxisSet m_axes;
    };
}
}
@@ -21,20 +21,21 @@
using namespace std;
using namespace ngraph;

-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+const string op::Divide::type_name{"Divide"};
+
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
-    , m_pythondiv(true)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
{
    constructor_validate_and_infer_types();
}

-op::Divide::Divide(const shared_ptr<Node>& arg0,
-                   const shared_ptr<Node>& arg1,
+op::Divide::Divide(const Output<Node>& arg0,
+                   const Output<Node>& arg1,
                    bool pythondiv,
                    const AutoBroadcastSpec& autob)
-    : BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
+    : BinaryElementwiseArithmetic(arg0, arg1, autob)
    , m_pythondiv(pythondiv)
{
    constructor_validate_and_infer_types();
@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
    adjoints.add_delta(y, -delta * shared_from_this() / y);
}

-shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
+shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
{
    return make_shared<op::Divide>(arg0, arg1);
}
...@@ -26,14 +26,19 @@ namespace ngraph ...@@ -26,14 +26,19 @@ namespace ngraph
class Divide : public util::BinaryElementwiseArithmetic class Divide : public util::BinaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type /// \param pythondiv Use Python style rounding for integral type
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0, Divide(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
bool pythondiv, bool pythondiv,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
...@@ -42,11 +47,12 @@ namespace ngraph ...@@ -42,11 +47,12 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0, Divide(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
bool is_pythondiv() const { return m_pythondiv; } bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
...@@ -54,10 +60,10 @@ namespace ngraph ...@@ -54,10 +60,10 @@ namespace ngraph
const NodeVector& deltas) override; const NodeVector& deltas) override;
protected: protected:
bool m_pythondiv; bool m_pythondiv{true};
}; };
} }
std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0, std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1); const Output<ngraph::Node> arg1);
} }
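A minimal usage sketch of the updated Divide signatures; the parameters, element types, and shapes are illustrative only, and op::Parameter is assumed from the existing nGraph API:

    // Assumes the ngraph op headers and `using namespace ngraph; using namespace std;` as above.
    shared_ptr<Node> a = make_shared<op::Parameter>(element::i32, Shape{4});
    shared_ptr<Node> b = make_shared<op::Parameter>(element::i32, Shape{4});
    auto d0 = make_shared<op::Divide>(a, b);                       // m_pythondiv defaults to true
    auto d1 = make_shared<op::Divide>(a, b, /*pythondiv=*/false);  // C++-style rounding for integral types
    auto d2 = a / b;  // operator/ now takes Output<Node>; shared_ptr<Node> converts implicitly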
...@@ -29,16 +29,18 @@ ...@@ -29,16 +29,18 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) const string op::Dot::type_name{"Dot"};
op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
: Dot(arg0, arg1, 0, false) : Dot(arg0, arg1, 0, false)
{ {
} }
op::Dot::Dot(const shared_ptr<Node>& arg0, op::Dot::Dot(const Output<Node>& arg0,
const shared_ptr<Node>& arg1, const Output<Node>& arg1,
size_t reduction_axes_count, size_t reduction_axes_count,
bool has_reduction_axes_count) bool has_reduction_axes_count)
: Op("Dot", check_single_output_args({arg0, arg1})) : Op({arg0, arg1})
, m_reduction_axes_count(reduction_axes_count) , m_reduction_axes_count(reduction_axes_count)
, m_has_reduction_axes_count(has_reduction_axes_count) , m_has_reduction_axes_count(has_reduction_axes_count)
{ {
...@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types() ...@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
set_output_type(0, result_et, result_shape); set_output_type(0, result_et, result_shape);
} }
shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n, shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
const Shape& front_shape, const Shape& front_shape,
const Shape& back_shape) const Shape& back_shape)
{ {
......
...@@ -28,13 +28,18 @@ namespace ngraph ...@@ -28,13 +28,18 @@ namespace ngraph
class Dot : public Op class Dot : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dot product operation.
Dot() = default;
/// \brief Constructs a dot product operation. /// \brief Constructs a dot product operation.
/// ///
/// \param arg0 The node producing the first argument. /// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument. /// \param arg1 The node producing the second argument.
/// \param reduction_axes_count The number of axes to dot. /// \param reduction_axes_count The number of axes to dot.
Dot(const std::shared_ptr<Node>& arg0, Dot(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
size_t reduction_axes_count, size_t reduction_axes_count,
bool has_reduction_axes_count = true); bool has_reduction_axes_count = true);
...@@ -48,11 +53,20 @@ namespace ngraph ...@@ -48,11 +53,20 @@ namespace ngraph
/// ///
/// \param arg0 The node producing the first argument. /// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument. /// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1); Dot(const Output<Node>& arg0, const Output<Node>& arg1);
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; } size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void set_reduction_axes_count(size_t reduction_axes_count)
void set_reduction_axes_count(size_t reduction_axes_count)
{
m_reduction_axes_count = reduction_axes_count;
}
bool get_has_reduction_axes_count() const { return m_has_reduction_axes_count; }
void set_has_reduction_axes_count(bool has_reduction_axes_count)
{
m_has_reduction_axes_count = has_reduction_axes_count;
}
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override copy_with_new_args(const NodeVector& new_args) const override
{ {
......
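A brief sketch of the converted Dot constructors; op::Parameter and the shapes are illustrative assumptions, not part of this change:

    // Assumes the ngraph op headers and `using namespace ngraph; using namespace std;` as above.
    shared_ptr<Node> A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    shared_ptr<Node> B = make_shared<op::Parameter>(element::f32, Shape{3, 4});
    auto matmul   = make_shared<op::Dot>(A, B, 1);  // contract one axis: result shape {2, 4}
    auto inferred = make_shared<op::Dot>(A, B);     // two-arg form: count left to be inferred at validation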
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::EmbeddingLookup::type_name{"EmbeddingLookup"};
void op::EmbeddingLookup::validate_and_infer_types() void op::EmbeddingLookup::validate_and_infer_types()
{ {
element::Type result_et = get_input_element_type(1); element::Type result_et = get_input_element_type(1);
......
...@@ -28,6 +28,11 @@ namespace ngraph ...@@ -28,6 +28,11 @@ namespace ngraph
class EmbeddingLookup : public Op class EmbeddingLookup : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an EmbeddingLookup operation.
EmbeddingLookup() = default;
/// \brief Constructs an EmbeddingLookup operation. /// \brief Constructs an EmbeddingLookup operation.
/// ///
/// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor /// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
...@@ -36,8 +41,8 @@ namespace ngraph ...@@ -36,8 +41,8 @@ namespace ngraph
/// \param data The input indices for tokens to be translated into embeddings /// \param data The input indices for tokens to be translated into embeddings
/// \param weights is a dense matrix [N,M] where each row 0..N /// \param weights is a dense matrix [N,M] where each row 0..N
/// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M /// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights) EmbeddingLookup(const Output<Node>& data, const Output<Node>& weights)
: Op("EmbeddingLookup", check_single_output_args({data, weights})) : Op({data, weights})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
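To make the data/weights roles concrete, a small sketch with invented indices, table, and shapes (all assumptions for illustration):

    // Assumes the ngraph op headers and `using namespace ngraph; using namespace std;` as above.
    shared_ptr<Node> indices = make_shared<op::Parameter>(element::i32, Shape{4});
    shared_ptr<Node> table   = make_shared<op::Parameter>(element::f32, Shape{10, 16});
    auto embeddings = make_shared<op::EmbeddingLookup>(indices, table);
    // each index selects one row of the 10x16 table, so the result shape is {4, 16}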
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Equal::Equal(const shared_ptr<Node>& arg0, const string op::Equal::type_name{"Equal"};
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob) op::Equal::Equal(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Equal", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -40,13 +40,18 @@ namespace ngraph ...@@ -40,13 +40,18 @@ namespace ngraph
class Equal : public util::BinaryElementwiseComparison class Equal : public util::BinaryElementwiseComparison
{ {
public: public:
/// \brief Constructs an is-equal operation. NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Equal(const std::shared_ptr<Node>& arg0, Equal(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -21,14 +21,16 @@ ...@@ -21,14 +21,16 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
const string op::Erf::type_name{"Erf"};
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Erf>(new_args.at(0)); return make_shared<Erf>(new_args.at(0));
} }
op::Erf::Erf(shared_ptr<Node> arg) op::Erf::Erf(const Output<Node>& arg)
: UnaryElementwiseArithmetic("Erf", arg) : UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -27,7 +27,11 @@ namespace ngraph ...@@ -27,7 +27,11 @@ namespace ngraph
class Erf : public util::UnaryElementwiseArithmetic class Erf : public util::UnaryElementwiseArithmetic
{ {
public: public:
Erf(std::shared_ptr<Node> arg); NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Erf() = default;
Erf(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -20,8 +20,10 @@ ...@@ -20,8 +20,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Exp::Exp(const shared_ptr<Node>& arg) const string op::Exp::type_name{"Exp"};
: UnaryElementwiseArithmetic("Exp", arg)
op::Exp::Exp(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Exp : public util::UnaryElementwiseArithmetic class Exp : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an exponential operation.
Exp() = default;
/// \brief Constructs an exponential operation. /// \brief Constructs an exponential operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Exp(const std::shared_ptr<Node>& arg); Exp(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -19,8 +19,10 @@ ...@@ -19,8 +19,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Floor::Floor(const shared_ptr<Node>& arg) const string op::Floor::type_name{"Floor"};
: UnaryElementwiseArithmetic("Floor", arg)
op::Floor::Floor(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,10 +26,15 @@ namespace ngraph ...@@ -26,10 +26,15 @@ namespace ngraph
class Floor : public util::UnaryElementwiseArithmetic class Floor : public util::UnaryElementwiseArithmetic
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a floor operation.
Floor() = default;
/// \brief Constructs a floor operation. /// \brief Constructs a floor operation.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Floor(const std::shared_ptr<Node>& arg); Floor(const Output<Node>& arg);
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -24,10 +24,12 @@ ...@@ -24,10 +24,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Reshape::Reshape(const shared_ptr<Node>& arg, const string op::Reshape::type_name{"Reshape"};
op::Reshape::Reshape(const Output<Node>& arg,
const AxisVector& input_order, const AxisVector& input_order,
const Shape& output_shape) const Shape& output_shape)
: Op("Reshape", check_single_output_args({arg})) : Op({arg})
, m_input_order(input_order) , m_input_order(input_order)
, m_output_shape(output_shape) , m_output_shape(output_shape)
{ {
......
...@@ -60,6 +60,11 @@ namespace ngraph ...@@ -60,6 +60,11 @@ namespace ngraph
class Reshape : public Op class Reshape : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a reshape operation.
Reshape() = default;
/// \brief Constructs a reshape operation. /// \brief Constructs a reshape operation.
/// ///
/// \param arg The tensor to be reshaped. /// \param arg The tensor to be reshaped.
...@@ -67,7 +72,7 @@ namespace ngraph ...@@ -67,7 +72,7 @@ namespace ngraph
/// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor. /// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
/// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must /// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$. /// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg, Reshape(const Output<Node>& arg,
const AxisVector& input_order, const AxisVector& input_order,
const Shape& output_shape); const Shape& output_shape);
...@@ -78,15 +83,18 @@ namespace ngraph ...@@ -78,15 +83,18 @@ namespace ngraph
/// \return The order in which to iterate over input axes. /// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; } const AxisVector& get_input_order() const { return m_input_order; }
void set_input_order(const AxisVector& input_order) { m_input_order = input_order; }
/// \return The shape of the output tensor. /// \return The shape of the output tensor.
const Shape& get_output_shape() const { return m_output_shape; } const Shape& get_output_shape() const { return m_output_shape; }
void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
bool get_is_transpose() const { return m_is_transpose; } bool get_is_transpose() const { return m_is_transpose; }
void set_is_transpose(bool is_transpose) { m_is_transpose = is_transpose; }
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
const AxisVector m_input_order; AxisVector m_input_order;
const Shape m_output_shape; Shape m_output_shape;
bool m_is_transpose{false}; bool m_is_transpose{false};
}; };
} }
......
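The input_order/output_shape pair documented above can be read as "iterate the input axes in this order, then pour the elements into the output shape"; a minimal sketch with invented shapes:

    // Assumes the ngraph op headers and `using namespace ngraph; using namespace std;` as above.
    shared_ptr<Node> t = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto transposed = make_shared<op::Reshape>(t, AxisVector{1, 0}, Shape{3, 2});  // a transpose
    auto flattened  = make_shared<op::Reshape>(t, AxisVector{0, 1}, Shape{6});     // a plain flatten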
...@@ -24,8 +24,10 @@ ...@@ -24,8 +24,10 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Result::Result(const shared_ptr<Node>& arg, bool needs_default_layout) const string op::Result::type_name{"Result"};
: Op("Result", check_single_output_args({arg}))
op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
: Op({arg})
, m_needs_default_layout(needs_default_layout) , m_needs_default_layout(needs_default_layout)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
......
...@@ -27,10 +27,15 @@ namespace ngraph ...@@ -27,10 +27,15 @@ namespace ngraph
class Result : public Op class Result : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result. /// \brief Allows a value to be used as a function result.
/// ///
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Result(const std::shared_ptr<Node>& arg, bool needs_default_layout = false); Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
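For completeness, a sketch of the new Result constructor. In typical use a Function wraps its outputs in Result nodes itself, so explicit construction like this mainly appears in passes and tests; that framing, and the value used here, are assumptions for illustration:

    // Assumes the ngraph op headers and `using namespace ngraph; using namespace std;` as above.
    shared_ptr<Node> value = make_shared<op::Parameter>(element::f32, Shape{2});
    auto result = make_shared<op::Result>(value);                              // needs_default_layout = false
    auto pinned = make_shared<op::Result>(value, /*needs_default_layout=*/true);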