Commit 7fbdfd5c authored by Scott Cyphers's avatar Scott Cyphers

Merge branch 's-barannikov/new_op_form/g_ops' into cyphers/s-barannikov

parents e4955613 43fe0711
......@@ -49,7 +49,7 @@ void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector&
adjoints.add_delta(y, delta);
}
shared_ptr<Node> ngraph::operator+(const shared_ptr<Node>& arg0, const shared_ptr<Node>& arg1)
shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Add>(arg0, arg1);
}
......@@ -58,6 +58,5 @@ namespace ngraph
};
}
std::shared_ptr<ngraph::Node> operator+(const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1);
std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
}
......@@ -64,7 +64,7 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
adjoints.add_delta(y, -delta * shared_from_this() / y);
}
shared_ptr<Node> ngraph::operator/(const Output<Node> arg0, const Output<Node> arg1)
shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Divide>(arg0, arg1);
}
......@@ -64,6 +64,5 @@ namespace ngraph
};
}
std::shared_ptr<ngraph::Node> operator/(const Output<ngraph::Node> arg0,
const Output<ngraph::Node> arg1);
std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
}
......@@ -58,7 +58,7 @@ namespace ngraph
void validate_and_infer_types() override;
size_t get_reduction_axes_count() const { return m_reduction_axes_count; }
void get_reduction_axes_count(size_t reduction_axes_count)
void set_reduction_axes_count(size_t reduction_axes_count)
{
m_reduction_axes_count = reduction_axes_count;
}
......
......@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
const string op::Gather::type_name{"Gather"};
shared_ptr<Node> op::Gather::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -26,13 +26,15 @@ namespace ngraph
class Gather : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
Gather() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
/// \param axis Axis in params to gather
Gather(const std::shared_ptr<Node>& params,
const std::shared_ptr<Node>& indices,
size_t axis = 0)
: Op("Gather", check_single_output_args({params, indices}))
Gather(const Output<Node>& params, const Output<Node>& indices, size_t axis = 0)
: Op({params, indices})
, m_axis(axis)
{
constructor_validate_and_infer_types();
......@@ -46,6 +48,7 @@ namespace ngraph
}
size_t get_axis() const { return m_axis; }
void set_axis(size_t axis) { m_axis = axis; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -23,6 +23,8 @@ using namespace ngraph;
static int PARAMS = 0;
static int INDICES = 1;
const string op::GatherND::type_name{"GatherND"};
shared_ptr<Node> op::GatherND::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -26,10 +26,14 @@ namespace ngraph
class GatherND : public Op
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
GatherND() = default;
/// \param params The tensor from which slices are gathered
/// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
GatherND(const std::shared_ptr<Node>& params, const std::shared_ptr<Node>& indices)
: Op("GatherND", check_single_output_args({params, indices}))
GatherND(const Output<Node>& params, const Output<Node>& indices)
: Op({params, indices})
{
constructor_validate_and_infer_types();
}
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::Greater::Greater(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::Greater::type_name{"Greater"};
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("Greater", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class Greater : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
Greater(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
......@@ -19,10 +19,12 @@
using namespace std;
using namespace ngraph;
op::GreaterEq::GreaterEq(const shared_ptr<Node>& arg0,
const shared_ptr<Node>& arg1,
const string op::GreaterEq::type_name{"GreaterEq"};
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob)
: BinaryElementwiseComparison("GreaterEq", arg0, arg1, autob)
: BinaryElementwiseComparison(arg0, arg1, autob)
{
constructor_validate_and_infer_types();
}
......
......@@ -26,13 +26,18 @@ namespace ngraph
class GreaterEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param autob Auto broadcast specification
GreaterEq(const std::shared_ptr<Node>& arg0,
const std::shared_ptr<Node>& arg1,
GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& autob = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment