Commit 7fbdfd5c authored by Scott Cyphers's avatar Scott Cyphers

Merge branch 's-barannikov/new_op_form/g_ops' into cyphers/s-barannikov

parents e4955613 43fe0711
...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& ...@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints.add_delta(y, delta); adjoints.add_delta(y, delta);
} }
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1) shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Add>(arg0, arg1); return make_shared<op::Add>(arg0, arg1);
} }
...@@ -58,6 +58,5 @@ namespace ngraph ...@@ -58,6 +58,5 @@ namespace ngraph
}; };
} }
std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0, std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
const std::shared_ptr<ngraph::Node>& arg1);
} }
...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y); adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1) shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Divide>(arg0, arg1); return make_shared<op::Divide>(arg0, arg1);
} }
...@@ -64,6 +64,5 @@ namespace ngraph ...@@ -64,6 +64,5 @@ namespace ngraph
}; };
} }
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0, std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
const Output<ngraph::Node> arg1);
} }
...@@ -58,7 +58,7 @@ namespace ngraph ...@@ -58,7 +58,7 @@ namespace ngraph
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; } size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count) void set_reduction_axes_count(size_t reduction_axes_count)
{ {
m_reduction_axes_count = reduction_axes_count; m_reduction_axes_count = reduction_axes_count;
} }
......
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::Gather::type_name{"Gather"};
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,13 +26,15 @@ namespace ngraph ...@@ -26,13 +26,15 @@ namespace ngraph
class Gather : public Op class Gather : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Gather() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather /// \param axis Axis in params to gather
Gather(const std::shared_ptr<Node>& params, Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
const std::shared_ptr<Node>& indices, : Op({params, indices})
size_t axis = 0)
: Op("Gather", check_single_output_args({params, indices}))
, m_axis(axis) , m_axis(axis)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -46,6 +48,7 @@ namespace ngraph ...@@ -46,6 +48,7 @@ namespace ngraph
} }
size_t get_axis() const { return m_axis; } size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -23,6 +23,8 @@ using namespace ngraph; ...@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0; static int PARAMS = 0;
static int INDICES = 1; static int INDICES = 1;
const string op::GatherND::type_name{"GatherND"};
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -26,10 +26,14 @@ namespace ngraph ...@@ -26,10 +26,14 @@ namespace ngraph
class GatherND : public Op class GatherND : public Op
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GatherND() = default;
/// \param params The tensor from which slices are gathered /// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64` /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices) GatherND(const Output<Node>& params, const Output<Node>& indices)
: Op("GatherND", check_single_output_args({params, indices})) : Op({params, indices})
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0, const string op::Greater::type_name{"Greater"};
const shared_ptr<Node>& arg1,
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Greater", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class Greater : public util::BinaryElementwiseComparison class Greater : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation. /// \brief Constructs a greater-than operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
Greater(const std::shared_ptr<Node>& arg0, Greater(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -19,10 +19,12 @@ ...@@ -19,10 +19,12 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0, const string op::GreaterEq::type_name{"GreaterEq"};
const shared_ptr<Node>& arg1,
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob) const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob) : BinaryElementwiseComparison(arg0, arg1, autob)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
......
...@@ -26,13 +26,18 @@ namespace ngraph ...@@ -26,13 +26,18 @@ namespace ngraph
class GreaterEq : public util::BinaryElementwiseComparison class GreaterEq : public util::BinaryElementwiseComparison
{ {
public: public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation. /// \brief Constructs a greater-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification /// \param autob Auto broadcast specification
GreaterEq(const std::shared_ptr<Node>& arg0, GreaterEq(const Output<Node>& arg0,
const std::shared_ptr<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec()); const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment