Unverified Commit 0552fd84 authored by Scott Cyphers's avatar Scott Cyphers Committed by GitHub

Convert to new op form (#3112)

parent 5e7aacf1
......@@ -559,6 +559,16 @@ const NodeVector& ngraph::check_single_output_args(const NodeVector& args)
return args;
}
OutputVector ngraph::as_output_vector(const NodeVector& args)
{
OutputVector output_vector;
for (auto& arg : check_single_output_args(args))
{
output_vector.push_back(arg);
}
return output_vector;
}
std::tuple<element::Type, PartialShape>
Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
{
......
......@@ -73,6 +73,8 @@ namespace ngraph
size_t i);
const NodeVector& check_single_output_args(const NodeVector& args);
OutputVector as_output_vector(const NodeVector& args);
/// Alias useful for cloning
using NodeMap = std::unordered_map<ngraph::Node*, std::shared_ptr<ngraph::Node>>;
......
......@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::Abs::type_name{"Abs"};
op::Abs::Abs()
{
}
op::Abs::Abs(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an absolute value operation.
Abs();
Abs() = default;
/// \brief Constructs an absolute value operation.
///
......
......@@ -34,10 +34,6 @@ using namespace ngraph;
const string op::Acos::type_name{"Acos"};
op::Acos::Acos()
{
}
op::Acos::Acos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arccos operation.
Acos();
Acos() = default;
/// \brief Constructs an arccos operation.
///
/// \param arg Output that produces the input tensor.<br>
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Add::type_name{"Add"};
op::Add::Add()
{
}
op::Add::Add(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an unitialized addition operation
Add();
Add() = default;
/// \brief Constructs an addition operation.
///
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::All::type_name{"All"};
op::All::All()
{
}
op::All::All(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "all" reduction operation.
All();
All() = default;
/// \brief Constructs an "all" reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::And::type_name{"And"};
op::And::And()
{
}
op::And::And(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseLogical(arg0, arg1, autob)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a logical-and operation.
And();
And() = default;
/// \brief Constructs a logical-and operation.
///
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::Any::type_name{"Any"};
op::Any::Any()
{
}
op::Any::Any(const Output<Node>& arg, const AxisSet& reduction_axes)
: LogicalReduction(arg, reduction_axes)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an "any" reduction operation.
Any();
Any() = default;
/// \brief Constructs an "any" reduction operation.
///
/// \param arg The tensor to be reduced.
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMax::type_name{"ArgMax"};
op::ArgMax::ArgMax()
{
}
op::ArgMax::ArgMax(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
......
......@@ -32,7 +32,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ArgMax operation.
ArgMax();
ArgMax() = default;
/// \brief Constructs a ArgMax operation.
///
/// \param arg The input tensor
......
......@@ -21,10 +21,6 @@ using namespace ngraph;
const string op::ArgMin::type_name{"ArgMin"};
op::ArgMin::ArgMin()
{
}
op::ArgMin::ArgMin(const Output<Node>& arg, size_t axis, const element::Type& index_element_type)
: op::util::IndexReduction(arg, axis, index_element_type)
{
......
......@@ -32,7 +32,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ArgMin operation.
ArgMin();
ArgMin() = default;
/// \brief Constructs a ArgMin operation.
///
......
......@@ -33,10 +33,6 @@ using namespace ngraph;
const string op::Asin::type_name{"Asin"};
op::Asin::Asin()
{
}
op::Asin::Asin(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arcsin operation.
Asin();
Asin() = default;
/// \brief Constructs an arcsin operation.
///
/// \param arg Output that produces the input tensor.<br>
......
......@@ -32,10 +32,6 @@ using namespace ngraph;
const string op::Atan::type_name{"Atan"};
op::Atan::Atan()
{
}
op::Atan::Atan(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an arctan operation.
Atan();
Atan() = default;
/// \brief Constructs an arctan operation.
///
......
......@@ -23,10 +23,6 @@ using namespace ngraph;
const string op::AvgPool::type_name{"AvgPool"};
op::AvgPool::AvgPool()
{
}
op::AvgPool::AvgPool(const Output<Node>& arg,
const Shape& window_shape,
const Strides& window_movement_strides,
......@@ -231,10 +227,6 @@ shared_ptr<Node> op::AvgPool::copy_with_new_args(const NodeVector& new_args) con
const string op::AvgPoolBackprop::type_name("AvgPoolBackprop");
op::AvgPoolBackprop::AvgPoolBackprop()
{
}
op::AvgPoolBackprop::AvgPoolBackprop(const Shape& forward_arg_shape,
const shared_ptr<Node>& delta,
const Shape& window_shape,
......
......@@ -33,7 +33,7 @@ namespace ngraph
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a batched average pooling operation.
AvgPool();
AvgPool() = default;
/// \brief Constructs a batched average pooling operation.
///
......@@ -175,7 +175,7 @@ namespace ngraph
public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
AvgPoolBackprop();
AvgPoolBackprop() = default;
AvgPoolBackprop(const Shape& forward_arg_shape,
const std::shared_ptr<Node>& delta,
const Shape& window_shape,
......
......@@ -22,11 +22,13 @@
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/validation_util.hpp"
ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
const std::string ngraph::op::BatchNormTraining::type_name{"BatchNormTraining"};
ngraph::op::BatchNormTraining::BatchNormTraining(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
double epsilon)
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
: Op({gamma, beta, input})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
......@@ -34,10 +36,10 @@ ngraph::op::BatchNormTraining::BatchNormTraining(std::shared_ptr<ngraph::Node> i
// DEPRECATED
ngraph::op::BatchNormTraining::BatchNormTraining(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input)
: Op("BatchNormTraining", check_single_output_args({gamma, beta, input}))
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input)
: Op({gamma, beta, input})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
......@@ -111,13 +113,15 @@ void ngraph::op::BatchNormTraining::generate_adjoints(autodiff::Adjoints& adjoin
adjoints.add_delta(beta, dbeta);
}
ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
const std::string ngraph::op::BatchNormInference::type_name{"BatchNormInference"};
ngraph::op::BatchNormInference::BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon)
: Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
: Op({gamma, beta, input, mean, variance})
, m_epsilon(epsilon)
{
constructor_validate_and_infer_types();
......@@ -125,12 +129,12 @@ ngraph::op::BatchNormInference::BatchNormInference(std::shared_ptr<ngraph::Node>
// DEPRECATED
ngraph::op::BatchNormInference::BatchNormInference(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance)
: Op("BatchNormInference", check_single_output_args({gamma, beta, input, mean, variance}))
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance)
: Op({gamma, beta, input, mean, variance})
, m_epsilon(eps)
{
constructor_validate_and_infer_types();
......@@ -167,16 +171,16 @@ std::shared_ptr<ngraph::Node>
new_args.at(2), new_args.at(0), new_args.at(1), new_args.at(3), new_args.at(4), m_epsilon);
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
std::shared_ptr<ngraph::Node> delta,
double epsilon)
: Op("BatchNormTrainingBackprop",
check_single_output_args({gamma, beta, input, mean, variance, delta}))
const std::string ngraph::op::BatchNormTrainingBackprop::type_name{"BatchNormTrainingBackprop"};
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta,
double epsilon)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
{
......@@ -184,16 +188,14 @@ ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
constructor_validate_and_infer_types();
}
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(
double epsilon,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
std::shared_ptr<ngraph::Node> delta)
: Op("BatchNormTrainingBackprop",
check_single_output_args({gamma, beta, input, mean, variance, delta}))
ngraph::op::BatchNormTrainingBackprop::BatchNormTrainingBackprop(double epsilon,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
Output<ngraph::Node> delta)
: Op({gamma, beta, input, mean, variance, delta})
, m_epsilon(epsilon)
{
......
......@@ -31,13 +31,17 @@ namespace ngraph
class BatchNormTraining : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTraining() = default;
/// \param input Must have rank >= 2, [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormTraining(std::shared_ptr<Node> input,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
BatchNormTraining(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -62,13 +66,14 @@ namespace ngraph
/// output[2]: shall have rank 1, with the same span as input's channel axis.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTraining(double eps,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input);
Output<Node> gamma,
Output<Node> beta,
Output<Node> input);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -87,17 +92,20 @@ namespace ngraph
class BatchNormInference : public Op
{
public:
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormInference() = default;
/// \param input [., C, ...]
/// \param gamma gamma scaling for normalized value. [C]
/// \param beta bias added to the scaled normalized value [C]
/// \param mean value for mean normalization [C]
/// \param variance value for variance normalization [C]
/// \param epsilon Avoids divsion by 0 if input has 0 variance
BatchNormInference(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance,
BatchNormInference(Output<ngraph::Node> input,
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance,
double epsilon);
NGRAPH_DEPRECATED_DOC
......@@ -120,15 +128,16 @@ namespace ngraph
/// output: shall have the same shape as 'input'.
NGRAPH_DEPRECATED("Use another constructor")
BatchNormInference(double eps,
std::shared_ptr<ngraph::Node> gamma,
std::shared_ptr<ngraph::Node> beta,
std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> mean,
std::shared_ptr<ngraph::Node> variance);
Output<ngraph::Node> gamma,
Output<ngraph::Node> beta,
Output<ngraph::Node> input,
Output<ngraph::Node> mean,
Output<ngraph::Node> variance);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -152,28 +161,33 @@ namespace ngraph
class BatchNormTrainingBackprop : public Op
{
public:
BatchNormTrainingBackprop(std::shared_ptr<Node> input,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance,
std::shared_ptr<Node> delta,
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BatchNormTrainingBackprop() = default;
BatchNormTrainingBackprop(Output<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> mean,
Output<Node> variance,
Output<Node> delta,
double epsilon);
NGRAPH_DEPRECATED_DOC
NGRAPH_DEPRECATED("Use another constructor")
BatchNormTrainingBackprop(double epsilon,
std::shared_ptr<Node> gamma,
std::shared_ptr<Node> beta,
std::shared_ptr<Node> input,
Output<Node> gamma,
Output<Node> beta,
Output<Node> input,
std::shared_ptr<Node> mean,
std::shared_ptr<Node> variance,
std::shared_ptr<Node> delta);
Output<Node> mean,
Output<Node> variance,
Output<Node> delta);
void validate_and_infer_types() override;
double get_eps_value() const { return m_epsilon; }
void set_eps_value(double epsilon) { m_epsilon = epsilon; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,21 +20,20 @@
using namespace std;
using namespace ngraph;
op::Broadcast::Broadcast(const std::string& name,
const NodeVector& args,
const string op::Broadcast::type_name{"Broadcast"};
op::Broadcast::Broadcast(const OutputVector& args,
const Shape& shape,
const AxisSet& broadcast_axes)
: Op(name, check_single_output_args(args))
: Op(args)
, m_shape(shape)
, m_broadcast_axes(broadcast_axes)
{
constructor_validate_and_infer_types();
}
op::Broadcast::Broadcast(const shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes)
: Broadcast("Broadcast", {arg}, shape, broadcast_axes)
op::Broadcast::Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes)
: Broadcast(OutputVector{arg}, shape, broadcast_axes)
{
}
......@@ -96,10 +95,12 @@ void op::Broadcast::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVe
adjoints.add_delta(x, make_shared<op::Sum>(delta, m_broadcast_axes));
}
op::BroadcastLike::BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
const string op::BroadcastLike::type_name{"BroadcastLike"};
op::BroadcastLike::BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes)
: Broadcast("BroadcastLike", {arg, like_arg}, {}, {})
: Broadcast({arg, like_arg}, {}, {})
, m_initial_broadcast_axes(initial_broadcast_axes)
{
constructor_validate_and_infer_types();
......
......@@ -27,15 +27,18 @@ namespace ngraph
class Broadcast : public Op
{
public:
/// \brief Constructs a conversion operation.
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a broadcast operation.
Broadcast() = default;
/// \brief Constructs a broadcast operation.
///
/// \param arg Node that produces the input tensor to be broadcast.
/// \param shape The shape of the output tensor.
/// \param broadcast_axes The axis positions (0-based) in the result that are being broadcast. The
/// remaining axes in shape must be the same as the shape of arg.
Broadcast(const std::shared_ptr<Node>& arg,
const Shape& shape,
const AxisSet& broadcast_axes);
Broadcast(const Output<Node>& arg, const Shape& shape, const AxisSet& broadcast_axes);
void validate_and_infer_types() override;
......@@ -44,12 +47,14 @@ namespace ngraph
/// \return A set containing the indices of the broadcast axes (0-based).
const AxisSet& get_broadcast_axes() const { return m_broadcast_axes; }
void set_broadcast_axes(const AxisSet& broadcast_axes)
{
m_broadcast_axes = broadcast_axes;
}
const Shape& get_broadcast_shape() const { return m_shape; }
void set_broadcast_shape(const Shape& shape) { m_shape = shape; }
protected:
Broadcast(const std::string& node_type,
const NodeVector& args,
const Shape& shape,
const AxisSet& broadcast_axes);
Broadcast(const OutputVector& args, const Shape& shape, const AxisSet& broadcast_axes);
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
......@@ -63,6 +68,11 @@ namespace ngraph
class BroadcastLike : public Broadcast
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Broadcast arg to the same shape as like_arg.
BroadcastLike() = default;
/// \brief Broadcast arg to the same shape as like_arg.
///
/// Once the shape of like_arg is known, this op will be replaced with an equivalent
......@@ -72,8 +82,8 @@ namespace ngraph
/// \param like_arg Provides the shape for the result.
/// \param initial_broadcast_axes indicates which axes will be broadcast. If empty,
/// arg must be scalar and all axes are broadcast.
BroadcastLike(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& like_arg,
BroadcastLike(const Output<Node>& arg,
const Output<Node>& like_arg,
const AxisSet& initial_broadcast_axes);
virtual std::shared_ptr<Node>
......@@ -81,6 +91,11 @@ namespace ngraph
void infer_shape() override;
const AxisSet& get_initial_broadcast_axes() const { return m_initial_broadcast_axes; }
void set_initial_broadcast_axes(const AxisSet& initial_broadcast_axes)
{
m_initial_broadcast_axes = initial_broadcast_axes;
}
protected:
AxisSet m_initial_broadcast_axes;
};
......
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::BroadcastDistributed::BroadcastDistributed(const shared_ptr<Node>& arg, int root_id)
: Op("BroadcastDistributed", check_single_output_args({arg}))
const string op::BroadcastDistributed::type_name{"BroadcastDistributed"};
op::BroadcastDistributed::BroadcastDistributed(const Output<Node>& arg, int root_id)
: Op({arg})
, m_root_id(root_id)
{
constructor_validate_and_infer_types();
......@@ -49,3 +51,8 @@ int op::BroadcastDistributed::get_root_id() const
{
return m_root_id;
}
void op::BroadcastDistributed::set_root_id(int root_id)
{
m_root_id = root_id;
}
......@@ -27,16 +27,21 @@ namespace ngraph
class BroadcastDistributed : public Op
{
public:
BroadcastDistributed(const std::shared_ptr<Node>& arg, int root_id = 0);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
BroadcastDistributed() = default;
BroadcastDistributed(const Output<Node>& arg, int root_id = 0);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
int get_root_id() const;
void set_root_id(int root_id);
private:
const int m_root_id;
int m_root_id;
};
}
}
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::Ceiling::Ceiling(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Ceiling", arg)
const string op::Ceiling::type_name{"Ceiling"};
op::Ceiling::Ceiling(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Ceiling : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a ceiling operation.
Ceiling() = default;
/// \brief Constructs a ceiling operation.
///
/// \param arg Node that produces the input tensor.
Ceiling(const std::shared_ptr<Node>& arg);
Ceiling(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -22,13 +22,20 @@
using namespace std;
using namespace ngraph;
op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
: Op("Concat", check_single_output_args(args))
const string op::Concat::type_name{"Concat"};
op::Concat::Concat(const OutputVector& args, size_t concatenation_axis)
: Op(args)
, m_concatenation_axis(concatenation_axis)
{
constructor_validate_and_infer_types();
}
op::Concat::Concat(const NodeVector& args, size_t concatenation_axis)
: Concat(as_output_vector(args), concatenation_axis)
{
}
void op::Concat::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
......
......@@ -28,6 +28,17 @@ namespace ngraph
class Concat : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a concatenation operation.
Concat() = default;
/// \brief Constructs a concatenation operation.
///
/// \param args The outputs producing the input tensors.
/// \param concatenation_axis The axis along which to concatenate the input tensors.
Concat(const OutputVector& args, size_t concatenation_axis);
/// \brief Constructs a concatenation operation.
///
/// \param args The nodes producing the input tensors.
......@@ -41,10 +52,15 @@ namespace ngraph
/// \return The concatenation axis.
size_t get_concatenation_axis() const { return m_concatenation_axis; }
void set_concatenation_axis(size_t concatenation_axis)
{
m_concatenation_axis = concatenation_axis;
}
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
const size_t m_concatenation_axis;
size_t m_concatenation_axis;
};
}
}
......@@ -45,6 +45,8 @@ string to_cpp_string(T value)
return rc;
}
const string op::Constant::type_name{"Constant"};
op::Constant::~Constant()
{
}
......
......@@ -34,6 +34,9 @@ namespace ngraph
class Constant : public Node
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a tensor constant.
///
/// \param type The element type of the tensor constant.
......@@ -78,7 +81,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param values A list of string values to use as the constant data.
Constant(const element::Type& type, Shape shape, const std::vector<std::string>& values)
: Node("Constant", {})
: Node({})
, m_element_type(type)
, m_shape(shape)
, m_data(new runtime::AlignedBuffer(shape_size(m_shape) * m_element_type.size(),
......@@ -135,7 +138,7 @@ namespace ngraph
/// \param shape The shape of the tensor constant.
/// \param data A void* to constant data.
Constant(const element::Type& type, const Shape& shape, const void* data)
: Node("Constant", {})
: Node({})
, m_element_type(type)
, m_shape(shape)
, m_data(nullptr)
......
......@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;
op::Convert::Convert(const shared_ptr<Node>& arg, const element::Type& element_type)
: Op("Convert", check_single_output_args({arg}))
const string op::Convert::type_name{"Convert"};
op::Convert::Convert(const Output<Node>& arg, const element::Type& element_type)
: Op({arg})
, m_element_type(element_type)
{
constructor_validate_and_infer_types();
......
......@@ -26,11 +26,16 @@ namespace ngraph
class Convert : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a conversion operation.
Convert() = default;
/// \brief Constructs a conversion operation.
///
/// \param arg Node that produces the input tensor.
/// \param element_type Element type for the output tensor.
Convert(const std::shared_ptr<Node>& arg, const ngraph::element::Type& element_type);
Convert(const Output<Node>& arg, const ngraph::element::Type& element_type);
void validate_and_infer_types() override;
......@@ -38,8 +43,13 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
const element::Type& get_convert_element_type() const { return m_element_type; }
void set_convert_element_type(const element::Type& element_type)
{
m_element_type = element_type;
}
protected:
const ngraph::element::Type m_element_type;
ngraph::element::Type m_element_type;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -27,15 +27,17 @@
using namespace std;
using namespace ngraph;
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
const string op::Convolution::type_name{"Convolution"};
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
const CoordinateDiff& padding_above,
const Strides& data_dilation_strides,
const PadType& pad_type)
: Op("Convolution", check_single_output_args({data_batch, filters}))
: Op({data_batch, filters})
, m_window_movement_strides(window_movement_strides)
, m_window_dilation_strides(window_dilation_strides)
, m_padding_below(padding_below)
......@@ -114,8 +116,8 @@ void op::Convolution::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides,
const CoordinateDiff& padding_below,
......@@ -130,8 +132,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides,
const Strides& window_dilation_strides)
: Convolution(data_batch,
......@@ -143,8 +145,8 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
const shared_ptr<Node>& filters,
op::Convolution::Convolution(const Output<Node>& data_batch,
const Output<Node>& filters,
const Strides& window_movement_strides)
: Convolution(data_batch,
filters,
......@@ -155,7 +157,7 @@ op::Convolution::Convolution(const shared_ptr<Node>& data_batch,
{
}
op::Convolution::Convolution(const shared_ptr<Node>& data_batch, const shared_ptr<Node>& filters)
op::Convolution::Convolution(const Output<Node>& data_batch, const Output<Node>& filters)
: Convolution(data_batch, filters, Strides(), Strides(), CoordinateDiff(), CoordinateDiff())
{
}
......@@ -204,15 +206,17 @@ void op::Convolution::generate_adjoints(autodiff::Adjoints& adjoints, const Node
m_data_dilation_strides));
}
const string op::ConvolutionBackpropData::type_name{"ConvolutionBackpropData"};
op::ConvolutionBackpropData::ConvolutionBackpropData(const Shape& data_batch_shape,
const shared_ptr<Node>& filters,
const shared_ptr<Node>& output_delta,
const Output<Node>& filters,
const Output<Node>& output_delta,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward)
: Op("ConvolutionBackpropData", check_single_output_args({filters, output_delta}))
: Op({filters, output_delta})
, m_data_batch_shape(data_batch_shape)
, m_window_movement_strides_forward(window_movement_strides_forward)
, m_window_dilation_strides_forward(window_dilation_strides_forward)
......@@ -332,14 +336,14 @@ void op::ConvolutionBackpropData::generate_adjoints(autodiff::Adjoints& adjoints
m_data_dilation_strides_forward[i]);
}
auto swap_NC = [](const shared_ptr<Node> n) {
AxisVector ax_order = ngraph::get_default_order(n->get_shape());
auto swap_NC = [](const Output<Node>& n) {
AxisVector ax_order = ngraph::get_default_order(n.get_shape());
ax_order[0] = 1;
ax_order[1] = 0;
auto new_shape = n->get_shape();
new_shape[0] = n->get_shape()[1];
new_shape[1] = n->get_shape()[0];
auto new_shape = n.get_shape();
new_shape[0] = n.get_shape()[1];
new_shape[1] = n.get_shape()[0];
return make_shared<op::Reshape>(n, ax_order, new_shape);
};
......@@ -422,16 +426,18 @@ CoordinateDiff op::ConvolutionBackpropData::compute_backward_delta_out_pad_above
return backward_delta_out_pad_above;
}
const string op::ConvolutionBackpropFilters::type_name{"ConvolutionBackpropFilters"};
op::ConvolutionBackpropFilters::ConvolutionBackpropFilters(
const shared_ptr<Node>& data_batch,
const Output<Node>& data_batch,
const Shape& filters_shape,
const shared_ptr<Node>& output_delta,
const Output<Node>& output_delta,
const Strides& window_movement_strides_forward,
const Strides& window_dilation_strides_forward,
const CoordinateDiff& padding_below_forward,
const CoordinateDiff& padding_above_forward,
const Strides& data_dilation_strides_forward)
: Op("ConvolutionBackpropFilters", check_single_output_args({data_batch, output_delta}))
: Op({data_batch, output_delta})
, m_filters_shape(filters_shape)
, m_window_movement_strides_forward(window_movement_strides_forward)
, m_window_dilation_strides_forward(window_dilation_strides_forward)
......
This diff is collapsed.
......@@ -22,8 +22,10 @@
using namespace std;
using namespace ngraph;
op::Cos::Cos(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cos", arg)
const string op::Cos::type_name{"Cos"};
op::Cos::Cos(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Cos : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a cosine operation.
Cos() = default;
/// \brief Constructs a cosine operation.
///
/// \param arg Node that produces the input tensor.
Cos(const std::shared_ptr<Node>& arg);
Cos(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -21,8 +21,10 @@
using namespace std;
using namespace ngraph;
op::Cosh::Cosh(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Cosh", arg)
const string op::Cosh::type_name{"Cosh"};
op::Cosh::Cosh(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Cosh : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a hyperbolic cosine operation.
Cosh() = default;
/// \brief Constructs a hyperbolic cosine operation.
///
/// \param arg Node that produces the input tensor.
Cosh(const std::shared_ptr<Node>& arg);
Cosh(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,13 +20,15 @@
using namespace std;
using namespace ngraph;
op::Dequantize::Dequantize(const shared_ptr<Node>& input,
const shared_ptr<Node>& scale,
const shared_ptr<Node>& zero_point,
const string op::Dequantize::type_name{"Dequantize"};
op::Dequantize::Dequantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type,
const AxisSet& axes)
: Op("Dequantize", check_single_output_args({input, scale, zero_point}))
: Op({input, scale, zero_point})
, m_type(type)
, m_axes(axes)
{
......
......@@ -30,31 +30,40 @@ namespace ngraph
class Dequantize : public ngraph::op::Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a Dequantize operation
Dequantize() = default;
/// \brief Constructs a Dequantize operation
/// \param input quantized input
/// \param scale scale used for mapping
/// \param zero_point zero point used for mapping
/// \param type output element type
/// \param axes axis positions on which `scale` and `zero_point` are specified
Dequantize(const std::shared_ptr<Node>& input,
const std::shared_ptr<Node>& scale,
const std::shared_ptr<Node>& zero_point,
const ngraph::element::Type& type,
const ngraph::AxisSet& axes);
Dequantize(const Output<Node>& input,
const Output<Node>& scale,
const Output<Node>& zero_point,
const element::Type& type,
const AxisSet& axes);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
const ngraph::AxisSet& get_axes() const { return m_axes; }
const AxisSet& get_axes() const { return m_axes; }
void set_axes(const AxisSet& axes) { m_axes = axes; }
const element::Type& get_type() const { return m_type; }
void set_type(const element::Type& type) { m_type = type; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
ngraph::element::Type m_type;
ngraph::AxisSet m_axes;
element::Type m_type;
AxisSet m_axes;
};
}
}
......@@ -21,20 +21,21 @@
using namespace std;
using namespace ngraph;
op::Divide::Divide(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Divide::type_name{"Divide"};
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
, m_pythondiv(true)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
op::Divide::Divide(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& autob)
: BinaryElementwiseArithmetic("Divide", arg0, arg1, autob)
: BinaryElementwiseArithmetic(arg0, arg1, autob)
, m_pythondiv(pythondiv)
{
constructor_validate_and_infer_types();
......@@ -63,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y);
}
shared_ptr<Node> ngraph::operator/(const shared_ptr<Node> arg0, const shared_ptr<Node> arg1)
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
{
return make_shared<op::Divide>(arg0, arg1);
}
......@@ -26,14 +26,19 @@ namespace ngraph
class Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
......@@ -42,11 +47,12 @@ namespace ngraph
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Divide(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -54,10 +60,10 @@ namespace ngraph
const NodeVector& deltas) override;
protected:
bool m_pythondiv;
bool m_pythondiv{true};
};
}
std::shared_ptr<ngraph::Node> operator/(const std::shared_ptr<ngraph::Node> arg0,
const std::shared_ptr<ngraph::Node> arg1);
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
}
......@@ -29,16 +29,18 @@
using namespace std;
using namespace ngraph;
op::Dot::Dot(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
const string op::Dot::type_name{"Dot"};
op::Dot::Dot(const Output<Node>& arg0, const Output<Node>& arg1)
: Dot(arg0, arg1, 0, false)
{
}
op::Dot::Dot(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
op::Dot::Dot(const Output<Node>& arg0,
const Output<Node>& arg1,
size_t reduction_axes_count,
bool has_reduction_axes_count)
: Op("Dot", check_single_output_args({arg0, arg1}))
: Op({arg0, arg1})
, m_reduction_axes_count(reduction_axes_count)
, m_has_reduction_axes_count(has_reduction_axes_count)
{
......@@ -154,7 +156,7 @@ void op::Dot::validate_and_infer_types()
set_output_type(0, result_et, result_shape);
}
shared_ptr<op::Reshape> make_reshape_axes_to_front(const shared_ptr<Node>& n,
shared_ptr<op::Reshape> make_reshape_axes_to_front(const Output<Node>& n,
const Shape& front_shape,
const Shape& back_shape)
{
......
......@@ -28,13 +28,18 @@ namespace ngraph
class Dot : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a dot product operation.
Dot() = default;
/// \brief Constructs a dot product operation.
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
/// \param reduction_axes_count The number of axes to dot.
Dot(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Dot(const Output<Node>& arg0,
const Output<Node>& arg1,
size_t reduction_axes_count,
bool has_reduction_axes_count = true);
......@@ -48,11 +53,20 @@ namespace ngraph
///
/// \param arg0 The node producing the first argument.
/// \param arg1 The node producing the second argument.
Dot(const std::shared_ptr<Node>& arg0, const std::shared_ptr<Node>& arg1);
Dot(const Output<Node>& arg0, const Output<Node>& arg1);
void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count)
{
m_reduction_axes_count = reduction_axes_count;
}
bool get_has_reduction_axes_count() const { return m_has_reduction_axes_count; }
void set_has_reduction_axes_count(bool has_reduction_axes_count)
{
m_has_reduction_axes_count = has_reduction_axes_count;
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override
{
......
......@@ -19,6 +19,8 @@
using namespace std;
using namespace ngraph;
const string op::EmbeddingLookup::type_name{"EmbeddingLookup"};
void op::EmbeddingLookup::validate_and_infer_types()
{
element::Type result_et = get_input_element_type(1);
......
......@@ -28,6 +28,11 @@ namespace ngraph
class EmbeddingLookup : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a EmbeddingLookup operation.
EmbeddingLookup() = default;
/// \brief Constructs a EmbeddingLookup operation.
///
/// EmbeddingLookup constructs an output tensor by replacing every index in a given input tensor
......@@ -36,8 +41,8 @@ namespace ngraph
/// \param data The input indices for tokens to be translated into embeddings
/// \param weights is a dense matrix [N,M] where each row 0..N
/// corresponds to an embedding (i.e. typically, a vector of real numbers) of length M
EmbeddingLookup(const std::shared_ptr<Node>& data, const std::shared_ptr<Node>& weights)
: Op("EmbeddingLookup", check_single_output_args({data, weights}))
EmbeddingLookup(const Output<Node>& data, const Output<Node>& weights)
: Op({data, weights})
{
constructor_validate_and_infer_types();
}
......
......@@ -19,10 +19,10 @@
using namespace std;
using namespace ngraph;
op::Equal::Equal(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Equal", arg0, arg1, autob)
const string op::Equal::type_name{"Equal"};
op::Equal::Equal(const Output<Node>& arg0, const Output<Node>& arg1, const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -40,13 +40,18 @@ namespace ngraph
class Equal : public util::BinaryElementwiseComparison
{
public:
/// \brief Constructs an is-equal operation.
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Equal(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -21,14 +21,16 @@
using namespace std;
using namespace ngraph;
const string op::Erf::type_name{"Erf"};
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Erf>(new_args.at(0));
}
op::Erf::Erf(shared_ptr<Node> arg)
: UnaryElementwiseArithmetic("Erf", arg)
op::Erf::Erf(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......@@ -27,7 +27,11 @@ namespace ngraph
class Erf : public util::UnaryElementwiseArithmetic
{
public:
Erf(std::shared_ptr<Node> arg);
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Erf() = default;
Erf(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -20,8 +20,10 @@
using namespace std;
using namespace ngraph;
op::Exp::Exp(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Exp", arg)
const string op::Exp::type_name{"Exp"};
op::Exp::Exp(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Exp : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs an exponential operation.
Exp() = default;
/// \brief Constructs an exponential operation.
///
/// \param arg Node that produces the input tensor.
Exp(const std::shared_ptr<Node>& arg);
Exp(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -19,8 +19,10 @@
using namespace std;
using namespace ngraph;
op::Floor::Floor(const shared_ptr<Node>& arg)
: UnaryElementwiseArithmetic("Floor", arg)
const string op::Floor::type_name{"Floor"};
op::Floor::Floor(const Output<Node>& arg)
: UnaryElementwiseArithmetic(arg)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,10 +26,15 @@ namespace ngraph
class Floor : public util::UnaryElementwiseArithmetic
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a floor operation.
Floor() = default;
/// \brief Constructs a floor operation.
///
/// \param arg Node that produces the input tensor.
Floor(const std::shared_ptr<Node>& arg);
Floor(const Output<Node>& arg);
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -24,10 +24,12 @@
using namespace std;
using namespace ngraph;
op::Reshape::Reshape(const shared_ptr<Node>& arg,
const string op::Reshape::type_name{"Reshape"};
op::Reshape::Reshape(const Output<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape)
: Op("Reshape", check_single_output_args({arg}))
: Op({arg})
, m_input_order(input_order)
, m_output_shape(output_shape)
{
......
......@@ -60,6 +60,11 @@ namespace ngraph
class Reshape : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a reshape operation.
Reshape() = default;
/// \brief Constructs a reshape operation.
///
/// \param arg The tensor to be reshaped.
......@@ -67,7 +72,7 @@ namespace ngraph
/// sequence \f$(0,\dots,n-1)\f$ where \f$n\f$ is the rank of the input tensor.
/// \param output_shape The output shape. If the input shape is \f$(a_0,\dots,a_{k-1})\f$ then the output shape must
/// be of the form \f$(b_0,\dots,b_{j-1})\f$ where \f$\Pi(a_i) = \Pi(b_i)\f$.
Reshape(const std::shared_ptr<Node>& arg,
Reshape(const Output<Node>& arg,
const AxisVector& input_order,
const Shape& output_shape);
......@@ -78,15 +83,18 @@ namespace ngraph
/// \return The order in which to iterate over input axes.
const AxisVector& get_input_order() const { return m_input_order; }
void set_input_order(const AxisVector& input_order) { m_input_order = input_order; }
/// \return The shape of the output tensor.
const Shape& get_output_shape() const { return m_output_shape; }
void set_output_shape(const Shape& output_shape) { m_output_shape = output_shape; }
bool get_is_transpose() const { return m_is_transpose; }
void set_is_transpose(bool is_transpose) { m_is_transpose = is_transpose; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
const AxisVector m_input_order;
const Shape m_output_shape;
AxisVector m_input_order;
Shape m_output_shape;
bool m_is_transpose{false};
};
}
......
......@@ -24,8 +24,10 @@
using namespace std;
using namespace ngraph;
op::Result::Result(const shared_ptr<Node>& arg, bool needs_default_layout)
: Op("Result", check_single_output_args({arg}))
const string op::Result::type_name{"Result"};
op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
: Op({arg})
, m_needs_default_layout(needs_default_layout)
{
constructor_validate_and_infer_types();
......
......@@ -27,10 +27,15 @@ namespace ngraph
class Result : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Allows a value to be used as a function result.
Result() = default;
/// \brief Allows a value to be used as a function result.
///
/// \param arg Node that produces the input tensor.
Result(const std::shared_ptr<Node>& arg, bool needs_default_layout = false);
Result(const Output<Node>& arg, bool needs_default_layout = false);
void validate_and_infer_types() override;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment