Commit f5b322cf authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[SPEC] Numpy AutoBroadcast as default for specific ops. (#3816)

* Add upgrade/downgrade pass for Add op.

* Add upgrade/downgrade pass for Divide op.

- Change default value of autobroadcasting in v1 into NUMPY.

* Add v1 version for Equal operator.

* Rename helper functions to fix compiler errors.

Fix Divide op version.

* Add downgrade and upgrade passes for Equal op.

* Reformat test cases. Add helper functions.

Add UT for Equal op.

* Add upgrade/downgrade pass and UT for Greater op.

* Add upgrade/downgrade pass and UT for GreaterEq op.

* Add upgrade/downgrade pass and UT for Less op.

* Add upgrade/downgrade pass and UT for LessEq op.

* Add upgrade/downgrade pass and UT for Maximum op.

* Add upgrade/downgrade pass and UT for Minimum op.

* Add upgrade/downgrade pass and UT for Multiply op.

* Add upgrade/downgrade pass and UT for NotEqual op.

* Add upgrade/downgrade pass and UT for Power op.

* Force ops version 1.

* Don't inline templates.

* Fix namespaces and some formatting.

* Update ONNX Importer to produce v1 nGraph nodes.

* Fix function return type.

* Fix uninitialized local variable warning.

* Fix conflicting declarations.

* Apply clang-format.

* Fix errors for distributed nGraph with unavailable classes.

* Fix downgrade pass for LessEqual op.
parent 1039ea1a
...@@ -47,10 +47,8 @@ namespace ngraph ...@@ -47,10 +47,8 @@ namespace ngraph
{ {
inline NodeVector add(const Node& node) inline NodeVector add(const Node& node)
{ {
return {std::make_shared<ngraph::op::Add>( return {std::make_shared<ngraph::op::v1::Add>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_7 } // namespace set_7
......
...@@ -47,10 +47,8 @@ namespace ngraph ...@@ -47,10 +47,8 @@ namespace ngraph
{ {
inline NodeVector div(const Node& node) inline NodeVector div(const Node& node)
{ {
return {std::make_shared<ngraph::op::Divide>( return {std::make_shared<ngraph::op::v1::Divide>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,10 +31,8 @@ namespace ngraph ...@@ -31,10 +31,8 @@ namespace ngraph
{ {
inline NodeVector equal(const Node& node) inline NodeVector equal(const Node& node)
{ {
return {std::make_shared<ngraph::op::Equal>( return {std::make_shared<ngraph::op::v1::Equal>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -31,10 +31,8 @@ namespace ngraph ...@@ -31,10 +31,8 @@ namespace ngraph
{ {
inline NodeVector greater(const Node& node) inline NodeVector greater(const Node& node)
{ {
return {std::make_shared<ngraph::op::Greater>( return {std::make_shared<ngraph::op::v1::Greater>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -46,9 +46,9 @@ namespace ngraph ...@@ -46,9 +46,9 @@ namespace ngraph
<< " alpha value should be in range (0,1)"; << " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node = std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), Shape{}, std::vector<double>{alpha}); data->get_shape(),
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape()); std::vector<double>{alpha});
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)}; return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
} }
......
...@@ -31,10 +31,8 @@ namespace ngraph ...@@ -31,10 +31,8 @@ namespace ngraph
{ {
inline NodeVector less(const Node& node) inline NodeVector less(const Node& node)
{ {
return {std::make_shared<ngraph::op::Less>( return {std::make_shared<ngraph::op::v1::Less>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
{ {
inline NodeVector max(const Node& node) inline NodeVector max(const Node& node)
{ {
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Maximum>(node); return variadic::make_ng_variadic_op<ngraph::op::v1::Maximum>(node);
} }
} // namespace set_8 } // namespace set_8
......
...@@ -49,8 +49,7 @@ namespace ngraph ...@@ -49,8 +49,7 @@ namespace ngraph
{ {
NodeVector mean(const Node& node) NodeVector mean(const Node& node)
{ {
auto sum = auto sum = variadic::make_ng_variadic_op<ngraph::op::v1::Add>(node).front();
variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node).front();
auto shape = sum->get_shape(); auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as // Create a Constant representing the number of inputs with the same shape as
......
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
{ {
inline NodeVector min(const Node& node) inline NodeVector min(const Node& node)
{ {
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Minimum>(node); return variadic::make_ng_variadic_op<ngraph::op::v1::Minimum>(node);
} }
} // namespace set_8 } // namespace set_8
......
...@@ -49,10 +49,8 @@ namespace ngraph ...@@ -49,10 +49,8 @@ namespace ngraph
{ {
inline NodeVector mul(const Node& node) inline NodeVector mul(const Node& node)
{ {
return {std::make_shared<ngraph::op::Multiply>( return {std::make_shared<ngraph::op::v1::Multiply>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_7 } // namespace set_7
......
...@@ -31,10 +31,8 @@ namespace ngraph ...@@ -31,10 +31,8 @@ namespace ngraph
{ {
inline NodeVector pow(const Node& node) inline NodeVector pow(const Node& node)
{ {
return {std::make_shared<ngraph::op::Power>( return {std::make_shared<ngraph::op::v1::Power>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
} }
} // namespace set_1 } // namespace set_1
......
...@@ -70,7 +70,7 @@ namespace ngraph ...@@ -70,7 +70,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> values_below_neg_lambd = std::shared_ptr<ngraph::Node> values_below_neg_lambd =
std::make_shared<ngraph::op::Less>(input, negative_lambd); std::make_shared<ngraph::op::Less>(input, negative_lambd);
std::shared_ptr<ngraph::Node> values_above_pos_lambd = std::shared_ptr<ngraph::Node> values_above_pos_lambd =
std::make_shared<ngraph::op::Greater>(input, positive_lambd); std::make_shared<ngraph::op::v1::Greater>(input, positive_lambd);
// Convert from bool to the input type to be able to multiply adjusted inputs // Convert from bool to the input type to be able to multiply adjusted inputs
// by the created masks // by the created masks
......
...@@ -16,9 +16,9 @@ ...@@ -16,9 +16,9 @@
#include <memory> #include <memory>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp" #include "ngraph/op/exp.hpp"
#include "ngraph/op/greater.hpp" #include "ngraph/op/greater.hpp"
#include "ngraph/op/log.hpp" #include "ngraph/op/log.hpp"
...@@ -40,9 +40,10 @@ namespace ngraph ...@@ -40,9 +40,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> zero_node = std::shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 0.f); std::make_shared<ngraph::op::Constant>(
std::shared_ptr<ngraph::Node> one_node = data->get_element_type(), data->get_shape(), std::vector<float>{0.f});
builder::make_constant(data->get_element_type(), data->get_shape(), 1.f); std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{1.f});
std::shared_ptr<ngraph::Node> positive_val_node = std::shared_ptr<ngraph::Node> positive_val_node =
data + std::make_shared<ngraph::op::Log>( data + std::make_shared<ngraph::op::Log>(
......
...@@ -40,7 +40,7 @@ namespace ngraph ...@@ -40,7 +40,7 @@ namespace ngraph
{ {
inline NodeVector sum(const Node& node) inline NodeVector sum(const Node& node)
{ {
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node); return variadic::make_ng_variadic_op<ngraph::op::v1::Add>(node);
} }
} // namespace set_8 } // namespace set_8
......
...@@ -38,9 +38,9 @@ namespace ngraph ...@@ -38,9 +38,9 @@ namespace ngraph
double alpha = node.get_attribute_value<double>("alpha", 1.0); double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node = std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>( std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha}); data->get_shape(),
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape()); std::vector<double>{alpha});
auto data_map = std::make_shared<ngraph::op::Convert>( auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node), std::make_shared<ngraph::op::Greater>(data, alpha_node),
......
...@@ -60,36 +60,6 @@ namespace ngraph ...@@ -60,36 +60,6 @@ namespace ngraph
return {result}; return {result};
} }
/// \brief Create an nGraph version of an ONNX variadic operation.
/// This creates a subgraph with a series of binary operations.
///
/// \param node Incoming ONNX operation.
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
///
/// \return nGraph node equivalent of the ONNX operation
template <class T>
inline NodeVector make_ng_variadic_op_with_broadcast(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
// Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) {
NodeVector args{ngraph::op::numpy_style_broadcast({arg0, arg1})};
return std::make_shared<T>(args.at(0), args.at(1));
};
// Create a result node as a series of binary operations
auto result = std::accumulate(
std::next(std::begin(ng_inputs)), // First operand value - the second input
std::end(ng_inputs), // Last value - final input
ng_inputs.front(), // Initial value - first input
binary_operation);
return {result};
}
} // namespace variadic } // namespace variadic
} // namespace onnx_import } // namespace onnx_import
......
...@@ -19,23 +19,25 @@ ...@@ -19,23 +19,25 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Add::type_info; // ------------------------------- v0 ------------------------------------------
op::Add::Add(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Add::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Add::Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Add::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Add>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -55,3 +57,37 @@ shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>& ...@@ -55,3 +57,37 @@ shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>&
{ {
return make_shared<op::Add>(arg0, arg1); return make_shared<op::Add>(arg0, arg1);
} }
// ------------------------------- v1 ------------------------------------------
constexpr NodeTypeInfo op::v1::Add::type_info;
op::v1::Add::Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Add>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v1::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta);
}
...@@ -24,39 +24,83 @@ namespace ngraph ...@@ -24,39 +24,83 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise addition operation. namespace v0
///
class Add : public util::BinaryElementwiseArithmetic
{ {
public: /// \brief Elementwise addition operation.
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
/// ///
/// \param arg0 Output that produces the first input tensor.<br> class Add : public util::BinaryElementwiseArithmetic
/// `[d0, ...]` {
/// \param arg1 Output that produces the second input tensor.<br> public:
/// `[d0, ...]` NGRAPH_API
/// \param auto_broadcast Auto broadcast specification static constexpr NodeTypeInfo type_info{"Add", 0};
/// const NodeTypeInfo& get_type_info() const override { return type_info; }
/// Output `[d0, ...]` /// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise addition operation.
/// ///
Add(const Output<Node>& arg0, class Add : public util::BinaryElementwiseArithmetic
const Output<Node>& arg1, {
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification. Default is Numpy-style
/// implicit broadcasting.
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override; virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
virtual bool is_commutative() const override { return true; } using v0::Add;
protected: } // namespace op
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1); std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
} } // namespace ngraph
...@@ -21,34 +21,36 @@ ...@@ -21,34 +21,36 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Divide::type_info; // ------------------------------ v0 -------------------------------------------
op::Divide::Divide(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Divide::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::Divide::Divide(const Output<Node>& arg0, op::v0::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
bool pythondiv, bool pythondiv,
const AutoBroadcastSpec& auto_broadcast) const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
, m_pythondiv(pythondiv) , m_pythondiv(pythondiv)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Divide::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v0::Divide::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Divide>( return make_shared<op::v0::Divide>(
new_args.at(0), new_args.at(1), this->is_pythondiv(), this->get_autob()); new_args.at(0), new_args.at(1), this->is_pythondiv(), this->get_autob());
} }
void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v0::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -66,5 +68,50 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto ...@@ -66,5 +68,50 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1) shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Divide>(arg0, arg1); return make_shared<op::v0::Divide>(arg0, arg1);
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Divide::type_info;
op::v1::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
op::v1::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
, m_pythondiv(pythondiv)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Divide::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Divide>(
new_args.at(0), new_args.at(1), this->is_pythondiv(), this->get_autob());
}
void op::v1::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta / y);
adjoints.add_delta(y, -delta * shared_from_this() / y);
} }
...@@ -22,47 +22,98 @@ namespace ngraph ...@@ -22,47 +22,98 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise division operation. namespace v0
class Divide : public util::BinaryElementwiseArithmetic
{ {
public: /// \brief Elementwise division operation.
NGRAPH_API class Divide : public util::BinaryElementwiseArithmetic
static constexpr NodeTypeInfo type_info{"Divide", 0}; {
const NodeTypeInfo& get_type_info() const override { return type_info; } public:
/// \brief Constructs a division operation. NGRAPH_API
Divide() = default; static constexpr NodeTypeInfo type_info{"Divide", 0};
/// \brief Constructs a division operation. const NodeTypeInfo& get_type_info() const override { return type_info; }
/// /// \brief Constructs a division operation.
/// \param arg0 Node that produces the first input tensor. Divide() = default;
/// \param arg1 Node that produces the second input tensor. /// \brief Constructs a division operation.
/// \param pythondiv Use Python style rounding for integral type ///
/// \param auto_broadcast Auto broadcast specification /// \param arg0 Node that produces the first input tensor.
Divide(const Output<Node>& arg0, /// \param arg1 Node that produces the second input tensor.
const Output<Node>& arg1, /// \param pythondiv Use Python style rounding for integral type
bool pythondiv, /// \param auto_broadcast Auto broadcast specification
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Constructs a division operation. /// \brief Constructs a division operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor. /// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification /// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0, Divide(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
bool is_pythondiv() const { return m_pythondiv; } bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; } void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
protected: protected:
bool m_pythondiv{true}; bool m_pythondiv{true};
}; };
} } // namespace v0
namespace v1
{
/// \brief Elementwise division operation.
class Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Divide", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
size_t get_version() const override { return 1; }
protected:
bool m_pythondiv{true};
};
} // namespace v1
using v0::Divide;
} // namespace op
std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1); std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
} } // namespace ngraph
...@@ -19,18 +19,38 @@ ...@@ -19,18 +19,38 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Equal::type_info; //------------------------------- v0 -------------------------------------------
op::Equal::Equal(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Equal::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Equal::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Equal>(new_args.at(0), new_args.at(1), this->get_autob());
}
//------------------------------- v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Equal::type_info;
op::v1::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Equal::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Equal::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Equal>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Equal>(new_args.at(0), new_args.at(1), this->get_autob());
} }
...@@ -22,44 +22,94 @@ namespace ngraph ...@@ -22,44 +22,94 @@ namespace ngraph
{ {
namespace op namespace op
{ {
// clang-format off namespace v0
/// \brief Elementwise is-equal operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
{ {
public: // clang-format off
NGRAPH_API /// \brief Elementwise is-equal operation.
static constexpr NodeTypeInfo type_info{"Equal", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// ## Inputs
/// \param arg1 Node that produces the second input tensor. ///
/// \param auto_broadcast Auto broadcast specification /// | | Type | Description |
Equal(const Output<Node>& arg0, /// | ------ | --------------------------------- | ------------------------------------------------------ |
const Output<Node>& arg1, /// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); /// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Equal", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
namespace v1
{
// clang-format off
/// \brief Elementwise is-equal operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Equal", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node> virtual bool is_commutative() const override { return true; }
copy_with_new_args(const NodeVector& new_args) const override; size_t get_version() const override { return 1; }
};
} // namespace v1
virtual bool is_commutative() const override { return true; } using v0::Equal;
};
} }
} }
...@@ -19,18 +19,38 @@ ...@@ -19,18 +19,38 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Greater::type_info; //-------------------------------------- v0 ------------------------------------
op::Greater::Greater(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Greater::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Greater::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
}
//-------------------------------------- v1 ------------------------------------
constexpr NodeTypeInfo op::v1::Greater::type_info;
op::v1::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Greater::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Greater::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Greater>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
} }
...@@ -22,26 +22,58 @@ namespace ngraph ...@@ -22,26 +22,58 @@ namespace ngraph
{
    namespace op
    {
        namespace v0
        {
            /// \brief Elementwise greater-than operation.
            class Greater : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Greater", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a greater-than operation.
                Greater() = default;
                /// \brief Constructs a greater-than operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification
                Greater(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
            };
        } // namespace v0

        namespace v1
        {
            /// \brief Elementwise greater-than operation.
            class Greater : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Greater", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a greater-than operation.
                Greater() = default;
                /// \brief Constructs a greater-than operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification. Defaults to
                ///        NUMPY (implicit numpy-style broadcasting), unlike v0.
                Greater(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast =
                            AutoBroadcastSpec(AutoBroadcastType::NUMPY));

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                size_t get_version() const override { return 1; }
            };
        } // namespace v1

        // Unqualified op::Greater continues to refer to the v0 op.
        using v0::Greater;
    }
}
...@@ -19,18 +19,38 @@ ...@@ -19,18 +19,38 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::GreaterEq::type_info; //---------------------------------- v0 ----------------------------------------
op::GreaterEq::GreaterEq(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::GreaterEq::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob());
}
//---------------------------------- v1 ----------------------------------------
constexpr NodeTypeInfo op::v1::GreaterEq::type_info;
op::v1::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::GreaterEq::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob());
} }
...@@ -22,26 +22,58 @@ namespace ngraph ...@@ -22,26 +22,58 @@ namespace ngraph
{
    namespace op
    {
        namespace v0
        {
            /// \brief Elementwise greater-than-or-equal operation.
            class GreaterEq : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"GreaterEq", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a greater-than-or-equal operation.
                GreaterEq() = default;
                /// \brief Constructs a greater-than-or-equal operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification
                GreaterEq(const Output<Node>& arg0,
                          const Output<Node>& arg1,
                          const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
            };
        } // namespace v0

        namespace v1
        {
            /// \brief Elementwise greater-than-or-equal operation.
            class GreaterEq : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"GreaterEq", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a greater-than-or-equal operation.
                GreaterEq() = default;
                /// \brief Constructs a greater-than-or-equal operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification. Defaults to
                ///        NUMPY (implicit numpy-style broadcasting), unlike v0.
                GreaterEq(const Output<Node>& arg0,
                          const Output<Node>& arg1,
                          const AutoBroadcastSpec& auto_broadcast =
                              AutoBroadcastSpec(AutoBroadcastType::NUMPY));

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                size_t get_version() const override { return 1; }
            };
        } // namespace v1

        // Unqualified op::GreaterEq continues to refer to the v0 op.
        using v0::GreaterEq;
    }
}
...@@ -19,18 +19,38 @@ ...@@ -19,18 +19,38 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Less::type_info; // ----------------------------- v0 --------------------------------------------
op::Less::Less(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Less::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Less::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Less>(new_args.at(0), new_args.at(1), this->get_autob());
}
// ----------------------------- v1 --------------------------------------------
constexpr NodeTypeInfo op::v1::Less::type_info;
op::v1::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Less::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Less::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Less>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Less>(new_args.at(0), new_args.at(1), this->get_autob());
} }
...@@ -22,26 +22,58 @@ namespace ngraph ...@@ -22,26 +22,58 @@ namespace ngraph
{
    namespace op
    {
        namespace v0
        {
            /// \brief Elementwise less-than operation.
            class Less : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Less", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a less-than operation.
                Less() = default;
                /// \brief Constructs a less-than operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification
                Less(const Output<Node>& arg0,
                     const Output<Node>& arg1,
                     const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;
            };
        } // namespace v0

        namespace v1
        {
            /// \brief Elementwise less-than operation.
            class Less : public util::BinaryElementwiseComparison
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Less", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a less-than operation.
                Less() = default;
                /// \brief Constructs a less-than operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification. Defaults to
                ///        NUMPY (implicit numpy-style broadcasting), unlike v0.
                Less(const Output<Node>& arg0,
                     const Output<Node>& arg1,
                     const AutoBroadcastSpec& auto_broadcast =
                         AutoBroadcastSpec(AutoBroadcastType::NUMPY));

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                size_t get_version() const override { return 1; }
            };
        } // namespace v1

        // Unqualified op::Less continues to refer to the v0 op.
        using v0::Less;
    }
}
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
// ---------------------------------- v1 ---------------------------------------

// Out-of-line definition required for the in-class static constexpr member.
constexpr NodeTypeInfo op::v1::LessEqual::type_info;
op::v1::LessEqual::LessEqual(const Output<Node>& arg0, op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
...@@ -35,6 +37,8 @@ shared_ptr<Node> op::v1::LessEqual::copy_with_new_args(const NodeVector& new_arg ...@@ -35,6 +37,8 @@ shared_ptr<Node> op::v1::LessEqual::copy_with_new_args(const NodeVector& new_arg
return make_shared<v1::LessEqual>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<v1::LessEqual>(new_args.at(0), new_args.at(1), this->get_autob());
} }
// ---------------------------------- v0 ---------------------------------------

// Out-of-line definition required for the in-class static constexpr member.
constexpr NodeTypeInfo op::v0::LessEq::type_info;
op::v0::LessEq::LessEq(const Output<Node>& arg0, op::v0::LessEq::LessEq(const Output<Node>& arg0,
......
...@@ -33,6 +33,7 @@ namespace ngraph ...@@ -33,6 +33,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
LessEqual() = default; LessEqual() = default;
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
...@@ -40,12 +41,14 @@ namespace ngraph ...@@ -40,12 +41,14 @@ namespace ngraph
/// \param auto_broadcast Auto broadcast specification /// \param auto_broadcast Auto broadcast specification
LessEqual(const Output<Node>& arg0, LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1, const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
} // namespace v1 } // namespace v1
namespace v0 namespace v0
{ {
/// \brief Elementwise less-than-or-equal operation. /// \brief Elementwise less-than-or-equal operation.
...@@ -57,6 +60,7 @@ namespace ngraph ...@@ -57,6 +60,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
LessEq() = default; LessEq() = default;
/// \brief Constructs a less-than-or-equal operation. /// \brief Constructs a less-than-or-equal operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// \param arg0 Node that produces the first input tensor.
......
...@@ -25,23 +25,62 @@ ...@@ -25,23 +25,62 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Maximum::type_info; // ------------------------------------ v0 -------------------------------------
op::Maximum::Maximum(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Maximum::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Maximum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Maximum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(
x,
delta * make_shared<op::Convert>(make_shared<op::v0::Greater>(x, y), x.get_element_type()));
adjoints.add_delta(
y,
delta * make_shared<op::Convert>(make_shared<op::v0::Greater>(y, x), y.get_element_type()));
}
// ------------------------------------ v1 -------------------------------------
constexpr NodeTypeInfo op::v1::Maximum::type_info;
op::v1::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Maximum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Maximum::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Maximum>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Maximum>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v1::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -53,7 +92,9 @@ void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -53,7 +92,9 @@ void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto x = input_value(0); auto x = input_value(0);
auto y = input_value(1); auto y = input_value(1);
adjoints.add_delta( adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), x.get_element_type())); x,
delta * make_shared<op::Convert>(make_shared<op::v1::Greater>(x, y), x.get_element_type()));
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y.get_element_type())); y,
delta * make_shared<op::Convert>(make_shared<op::v1::Greater>(y, x), y.get_element_type()));
} }
...@@ -22,31 +22,68 @@ namespace ngraph ...@@ -22,31 +22,68 @@ namespace ngraph
{
    namespace op
    {
        namespace v0
        {
            /// \brief Elementwise maximum operation.
            class Maximum : public util::BinaryElementwiseArithmetic
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Maximum", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a maximum operation.
                Maximum() = default;
                /// \brief Constructs a maximum operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification
                Maximum(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                virtual bool is_commutative() const override { return true; }
            protected:
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const NodeVector& deltas) override;
            };
        } // namespace v0

        namespace v1
        {
            /// \brief Elementwise maximum operation.
            class Maximum : public util::BinaryElementwiseArithmetic
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Maximum", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a maximum operation.
                Maximum() = default;
                /// \brief Constructs a maximum operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification. Defaults to
                ///        NUMPY (implicit numpy-style broadcasting), unlike v0.
                Maximum(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast =
                            AutoBroadcastSpec(AutoBroadcastType::NUMPY));

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                virtual bool is_commutative() const override { return true; }
                size_t get_version() const override { return 1; }
            protected:
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const NodeVector& deltas) override;
            };
        } // namespace v1

        // Unqualified op::Maximum continues to refer to the v0 op.
        using v0::Maximum;
    }
}
...@@ -25,23 +25,61 @@ ...@@ -25,23 +25,61 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Minimum::type_info; // ------------------------------ v0 -------------------------------------------
op::Minimum::Minimum(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Minimum::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Minimum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Minimum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::v0::Less>(x, y), x.get_element_type()));
adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::v0::Less>(y, x), y.get_element_type()));
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Minimum::type_info;
op::v1::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Minimum::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Minimum::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Minimum>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Minimum>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v1::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -54,7 +92,7 @@ void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect ...@@ -54,7 +92,7 @@ void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto y = input_value(1); auto y = input_value(1);
adjoints.add_delta( adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), x.get_element_type())); x, delta * make_shared<op::Convert>(make_shared<op::v1::Less>(x, y), x.get_element_type()));
adjoints.add_delta( adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y.get_element_type())); y, delta * make_shared<op::Convert>(make_shared<op::v1::Less>(y, x), y.get_element_type()));
} }
...@@ -22,31 +22,68 @@ namespace ngraph ...@@ -22,31 +22,68 @@ namespace ngraph
{
    namespace op
    {
        namespace v0
        {
            /// \brief Elementwise minimum operation.
            class Minimum : public util::BinaryElementwiseArithmetic
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Minimum", 0};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a minimum operation.
                Minimum() = default;
                /// \brief Constructs a minimum operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification
                Minimum(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                virtual bool is_commutative() const override { return true; }
            protected:
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const NodeVector& deltas) override;
            };
        } // namespace v0

        namespace v1
        {
            /// \brief Elementwise minimum operation.
            class Minimum : public util::BinaryElementwiseArithmetic
            {
            public:
                NGRAPH_API
                static constexpr NodeTypeInfo type_info{"Minimum", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                /// \brief Constructs a minimum operation.
                Minimum() = default;
                /// \brief Constructs a minimum operation.
                ///
                /// \param arg0 Node that produces the first input tensor.
                /// \param arg1 Node that produces the second input tensor.
                /// \param auto_broadcast Auto broadcast specification. Defaults to
                ///        NUMPY (implicit numpy-style broadcasting), unlike v0.
                Minimum(const Output<Node>& arg0,
                        const Output<Node>& arg1,
                        const AutoBroadcastSpec& auto_broadcast =
                            AutoBroadcastSpec(AutoBroadcastType::NUMPY));

                virtual std::shared_ptr<Node>
                    copy_with_new_args(const NodeVector& new_args) const override;

                virtual bool is_commutative() const override { return true; }
                size_t get_version() const override { return 1; }
            protected:
                virtual void generate_adjoints(autodiff::Adjoints& adjoints,
                                               const NodeVector& deltas) override;
            };
        } // namespace v1

        // Unqualified op::Minimum continues to refer to the v0 op.
        using v0::Minimum;
    }
}
...@@ -19,23 +19,59 @@ ...@@ -19,23 +19,59 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Multiply::type_info; // ------------------------------------ v0 -------------------------------------
op::Multiply::Multiply(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Multiply::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Multiply::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta);
}
// ------------------------------------ v1 -------------------------------------
constexpr NodeTypeInfo op::v1::Multiply::type_info;
op::v1::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Multiply::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Multiply::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Multiply>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v1::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
...@@ -51,6 +87,8 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec ...@@ -51,6 +87,8 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, x * delta); adjoints.add_delta(y, x * delta);
} }
// -----------------------------------------------------------------------------
shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1) shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
{ {
return make_shared<op::Multiply>(arg0, arg1); return make_shared<op::Multiply>(arg0, arg1);
......
...@@ -22,33 +22,70 @@ namespace ngraph ...@@ -22,33 +22,70 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise multiplication operation. namespace v0
class Multiply : public util::BinaryElementwiseArithmetic
{ {
public: /// \brief Elementwise multiplication operation.
NGRAPH_API class Multiply : public util::BinaryElementwiseArithmetic
static constexpr NodeTypeInfo type_info{"Multiply", 0}; {
const NodeTypeInfo& get_type_info() const override { return type_info; } public:
/// \brief Constructs a multiplication operation. NGRAPH_API
Multiply() = default; static constexpr NodeTypeInfo type_info{"Multiply", 0};
/// \brief Constructs a multiplication operation. const NodeTypeInfo& get_type_info() const override { return type_info; }
/// /// \brief Constructs a multiplication operation.
/// \param arg0 Node that produces the first input tensor. Multiply() = default;
/// \param arg1 Node that produces the second input tensor. /// \brief Constructs a multiplication operation.
/// \param auto_broadcast Auto broadcast specification ///
Multiply(const Output<Node>& arg0, /// \param arg0 Node that produces the first input tensor.
const Output<Node>& arg1, /// \param arg1 Node that produces the second input tensor.
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); /// \param auto_broadcast Auto broadcast specification
Multiply(const Output<Node>& arg0,
virtual std::shared_ptr<Node> const Output<Node>& arg1,
copy_with_new_args(const NodeVector& new_args) const override; const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual bool is_commutative() const override { return true; } virtual std::shared_ptr<Node>
protected: copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; virtual bool is_commutative() const override { return true; }
}; protected:
}; virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise multiplication operation.
class Multiply : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Multiply", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Multiply;
} // namespace op
std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1); std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
} } // namespace ngraph
...@@ -19,18 +19,38 @@ ...@@ -19,18 +19,38 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::NotEqual::type_info; // ----------------------------------- v0 --------------------------------------
op::NotEqual::NotEqual(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::NotEqual::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
}
// ----------------------------------- v1 --------------------------------------
constexpr NodeTypeInfo op::v1::NotEqual::type_info;
op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast) : BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::NotEqual::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<NotEqual>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
} }
...@@ -22,28 +22,62 @@ namespace ngraph ...@@ -22,28 +22,62 @@ namespace ngraph
{ {
namespace op namespace op
{ {
/// \brief Elementwise not-equal operation. namespace v0
class NotEqual : public util::BinaryElementwiseComparison
{ {
public: /// \brief Elementwise not-equal operation.
NGRAPH_API class NotEqual : public util::BinaryElementwiseComparison
static constexpr NodeTypeInfo type_info{"NotEqual", 0}; {
const NodeTypeInfo& get_type_info() const override { return type_info; } public:
/// \brief Constructs a not-equal operation. NGRAPH_API
NotEqual() = default; static constexpr NodeTypeInfo type_info{"NotEqual", 0};
/// \brief Constructs a not-equal operation. const NodeTypeInfo& get_type_info() const override { return type_info; }
/// /// \brief Constructs a not-equal operation.
/// \param arg0 Node that produces the first input tensor. NotEqual() = default;
/// \param arg1 Node that produces the second input tensor. /// \brief Constructs a not-equal operation.
/// \param auto_broadcast Auto broadcast specification ///
NotEqual(const Output<Node>& arg0, /// \param arg0 Node that produces the first input tensor.
const Output<Node>& arg1, /// \param arg1 Node that produces the second input tensor.
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); /// \param auto_broadcast Auto broadcast specification
NotEqual(const Output<Node>& arg0,
virtual std::shared_ptr<Node> const Output<Node>& arg1,
copy_with_new_args(const NodeVector& new_args) const override; const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual bool is_commutative() const override { return true; } virtual std::shared_ptr<Node>
}; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
namespace v1
{
/// \brief Elementwise not-equal operation.
class NotEqual : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"NotEqual", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
};
} // namespace v1
using v0::NotEqual;
} }
} }
...@@ -22,23 +22,61 @@ ...@@ -22,23 +22,61 @@
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
constexpr NodeTypeInfo op::Power::type_info; // ------------------------------ v0 -------------------------------------------
op::Power::Power(const Output<Node>& arg0, constexpr NodeTypeInfo op::v0::Power::type_info;
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast) op::v0::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Power::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Power>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
auto log_x = make_shared<op::Log>(x);
adjoints.add_delta(x, delta * y * shared_from_this() / x);
adjoints.add_delta(y, delta * shared_from_this() * log_x);
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Power::type_info;
op::v1::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast) : BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
shared_ptr<Node> op::Power::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::v1::Power::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<Power>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<op::v1::Power>(new_args.at(0), new_args.at(1), this->get_autob());
} }
void op::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::v1::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{ {
if (get_autob().m_type != op::AutoBroadcastType::NONE) if (get_autob().m_type != op::AutoBroadcastType::NONE)
{ {
......
...@@ -22,44 +22,93 @@ namespace ngraph ...@@ -22,44 +22,93 @@ namespace ngraph
{ {
namespace op namespace op
{ {
// clang-format off namespace v0
/// \brief Elementwise exponentiation operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
{ {
public: // clang-format off
NGRAPH_API /// \brief Elementwise exponentiation operation.
static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
/// ///
/// \param arg0 Node that produces the first input tensor. /// ## Inputs
/// \param arg1 Node that produces the second input tensor. ///
/// \param auto_broadcast Auto broadcast specification /// | | Type | Description |
Power(const Output<Node>& arg0, /// | ------ | --------------------------------- | ------------------------------------------------------ |
const Output<Node>& arg1, /// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec()); /// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
// clang-format off
/// \brief Elementwise exponentiation operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
protected: using v0::Power;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} }
} }
...@@ -18,19 +18,30 @@ ...@@ -18,19 +18,30 @@
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp" #include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp" #include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp" #include "ngraph/op/reduce_sum.hpp"
...@@ -79,6 +90,17 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node) ...@@ -79,6 +90,17 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node)
} }
// END mapping to OP_TYPEID // END mapping to OP_TYPEID
template <typename OpV0, typename OpV1>
void downgrade_binary_elementwise_node(const shared_ptr<Node>& node)
{
const auto tmp = as_type_ptr<OpV1>(node);
const auto input_arg0 = node->input(0).get_source_output();
const auto input_arg1 = node->input(1).get_source_output();
const auto autob = tmp->get_autob();
auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
replace_node(node, replacement_node);
}
bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
{ {
bool modified = false; bool modified = false;
...@@ -104,6 +126,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -104,6 +126,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
#endif #endif
switch (get_typeid(node)) switch (get_typeid(node))
{ {
case OP_TYPEID::Add:
{
downgrade_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
modified = true;
break;
}
case OP_TYPEID::AvgPool: case OP_TYPEID::AvgPool:
{ {
const auto tmp = as_type_ptr<op::v1::AvgPool>(node); const auto tmp = as_type_ptr<op::v1::AvgPool>(node);
...@@ -250,6 +278,18 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -250,6 +278,18 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Divide:
{
const auto tmp = as_type_ptr<op::v1::Divide>(node);
const auto input_arg0 = node->input(0).get_source_output();
const auto input_arg1 = node->input(1).get_source_output();
const auto autob = tmp->get_autob();
const bool pydiv = tmp->is_pythondiv();
auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::DynReshape: case OP_TYPEID::DynReshape:
{ {
auto tmp = as_type_ptr<op::v1::Reshape>(node); auto tmp = as_type_ptr<op::v1::Reshape>(node);
...@@ -260,6 +300,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -260,6 +300,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Equal:
{
downgrade_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
modified = true;
break;
}
case OP_TYPEID::GenerateMask: case OP_TYPEID::GenerateMask:
{ {
auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get()); auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get());
...@@ -279,23 +325,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -279,23 +325,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Greater:
{
downgrade_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
modified = true;
break;
}
case OP_TYPEID::GreaterEq:
{
downgrade_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEq>(node);
modified = true;
break;
}
case OP_TYPEID::Less:
{
downgrade_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
modified = true;
break;
}
case OP_TYPEID::LessEqual: case OP_TYPEID::LessEqual:
{ {
auto less_eq_v1 = as_type_ptr<op::v1::LessEqual>(node); downgrade_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
auto replacement_node = make_shared<op::v0::LessEq>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
less_eq_v1->get_autob());
replace_node(node, replacement_node);
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::LogicalAnd: case OP_TYPEID::LogicalAnd:
{ {
auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(node); downgrade_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
auto replacement_node = make_shared<op::v0::And>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v1->get_autob());
replace_node(node, replacement_node);
modified = true; modified = true;
break; break;
} }
...@@ -307,21 +363,19 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -307,21 +363,19 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
} }
case OP_TYPEID::LogicalOr: case OP_TYPEID::LogicalOr:
{ {
auto or_v1 = as_type_ptr<op::v1::LogicalOr>(node); downgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
auto replacement_node = make_shared<op::v0::Or>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v1->get_autob());
replace_node(node, replacement_node);
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::LogicalXor: case OP_TYPEID::LogicalXor:
{ {
auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(node); downgrade_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
auto replacement_node = make_shared<op::v0::Xor>(node->input(0).get_source_output(), modified = true;
node->input(1).get_source_output(), break;
xor_v1->get_autob()); }
replace_node(node, replacement_node); case OP_TYPEID::Maximum:
{
downgrade_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
modified = true; modified = true;
break; break;
} }
...@@ -385,6 +439,24 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -385,6 +439,24 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Minimum:
{
downgrade_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
modified = true;
break;
}
case OP_TYPEID::Multiply:
{
downgrade_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
modified = true;
break;
}
case OP_TYPEID::NotEqual:
{
downgrade_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Pad: case OP_TYPEID::Pad:
{ {
auto tmp = as_type_ptr<op::v1::Pad>(node); auto tmp = as_type_ptr<op::v1::Pad>(node);
...@@ -397,6 +469,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node) ...@@ -397,6 +469,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Power:
{
downgrade_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
modified = true;
break;
}
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
auto tmp = as_type_ptr<op::v1::ReduceProd>(node); auto tmp = as_type_ptr<op::v1::ReduceProd>(node);
......
...@@ -15,19 +15,30 @@ ...@@ -15,19 +15,30 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp" #include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp" #include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp" #include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp" #include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp" #include "ngraph/op/reduce_sum.hpp"
...@@ -77,6 +88,16 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node) ...@@ -77,6 +88,16 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node)
} }
// END mapping to OP_TYPEID // END mapping to OP_TYPEID
template <typename OpV0, typename OpV1>
void upgrade_binary_elementwise_node(const shared_ptr<Node>& node)
{
const auto tmp = dynamic_cast<const OpV0*>(node.get());
const auto autob = tmp->get_autob();
auto replacement_node = make_shared<OpV1>(
node->input(0).get_source_output(), node->input(1).get_source_output(), autob);
replace_node(node, replacement_node);
}
bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{ {
bool modified = false; bool modified = false;
...@@ -102,13 +123,15 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -102,13 +123,15 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
#endif #endif
switch (get_typeid(node)) switch (get_typeid(node))
{ {
case OP_TYPEID::Add:
{
upgrade_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
modified = true;
break;
}
case OP_TYPEID::And: case OP_TYPEID::And:
{ {
const auto and_v0 = dynamic_cast<const op::v0::And*>(node.get()); upgrade_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
auto replacement_node = make_shared<op::v1::LogicalAnd>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v0->get_autob());
replace_node(node, replacement_node);
modified = true; modified = true;
break; break;
} }
...@@ -284,6 +307,17 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -284,6 +307,17 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Divide:
{
const auto tmp = dynamic_cast<const op::v0::Divide*>(node.get());
const auto autob = tmp->get_autob();
const bool pydiv = tmp->is_pythondiv();
auto replacement_node = make_shared<op::v1::Divide>(
node->input(0).get_source_output(), node->input(1).get_source_output(), pydiv, autob);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::DynReshape: case OP_TYPEID::DynReshape:
{ {
auto zero_flag = false; auto zero_flag = false;
...@@ -293,6 +327,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -293,6 +327,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Equal:
{
upgrade_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
modified = true;
break;
}
case OP_TYPEID::Gather: case OP_TYPEID::Gather:
{ {
auto tmp = dynamic_cast<const op::v0::Gather*>(node.get()); auto tmp = dynamic_cast<const op::v0::Gather*>(node.get());
...@@ -305,13 +345,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -305,13 +345,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Greater:
{
upgrade_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
modified = true;
break;
}
case OP_TYPEID::GreaterEq:
{
upgrade_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEq>(node);
modified = true;
break;
}
case OP_TYPEID::Less:
{
upgrade_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
modified = true;
break;
}
case OP_TYPEID::LessEq: case OP_TYPEID::LessEq:
{ {
const auto less_eq_v0 = dynamic_cast<const op::v0::LessEq*>(node.get()); upgrade_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
auto replacement_node = make_shared<op::v1::LessEqual>(node->input(0).get_source_output(), modified = true;
node->input(1).get_source_output(), break;
less_eq_v0->get_autob()); }
replace_node(node, replacement_node); case OP_TYPEID::Maximum:
{
upgrade_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
modified = true; modified = true;
break; break;
} }
...@@ -372,19 +432,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -372,19 +432,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Minimum:
{
upgrade_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
modified = true;
break;
}
case OP_TYPEID::Multiply:
{
upgrade_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
modified = true;
break;
}
case OP_TYPEID::Not: case OP_TYPEID::Not:
{ {
replace_node(node, make_shared<op::v1::LogicalNot>(node->input(0).get_source_output())); replace_node(node, make_shared<op::v1::LogicalNot>(node->input(0).get_source_output()));
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::NotEqual:
{
upgrade_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Or: case OP_TYPEID::Or:
{ {
const auto or_v0 = dynamic_cast<const op::v0::Or*>(node.get()); upgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
auto replacement_node = make_shared<op::v1::LogicalOr>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v0->get_autob());
replace_node(node, replacement_node);
modified = true; modified = true;
break; break;
} }
...@@ -408,6 +482,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node) ...@@ -408,6 +482,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true; modified = true;
break; break;
} }
case OP_TYPEID::Power:
{
upgrade_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
modified = true;
break;
}
case OP_TYPEID::Product: case OP_TYPEID::Product:
{ {
bool keep_dims = false; bool keep_dims = false;
......
...@@ -21,19 +21,30 @@ ...@@ -21,19 +21,30 @@
#include "ngraph/code_writer.hpp" #include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp" #include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp" #include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp" #include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/generate_mask.hpp" #include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/gather.hpp" #include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp" #include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp" #include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp" #include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp" #include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp" #include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp" #include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
...@@ -55,7 +66,6 @@ namespace ngraph ...@@ -55,7 +66,6 @@ namespace ngraph
{ {
namespace op namespace op
{ {
class Add;
class AllReduce; class AllReduce;
class BroadcastDistributed; class BroadcastDistributed;
class MatmulBias; class MatmulBias;
...@@ -69,23 +79,14 @@ namespace ngraph ...@@ -69,23 +79,14 @@ namespace ngraph
class BatchNormInferenceRelu; class BatchNormInferenceRelu;
class BatchNormTrainingBackprop; class BatchNormTrainingBackprop;
class Dot; class Dot;
class Multiply;
class GetOutputElement; class GetOutputElement;
class Abs; class Abs;
class Concat; class Concat;
class Divide;
class Equal;
class Greater;
class GreaterEq;
class Less;
class Any; class Any;
class All; class All;
class LRN; class LRN;
class Log; class Log;
class Maximum;
class Minimum;
class Negative; class Negative;
class NotEqual;
class Select; class Select;
class Subtract; class Subtract;
class Convert; class Convert;
...@@ -107,7 +108,6 @@ namespace ngraph ...@@ -107,7 +108,6 @@ namespace ngraph
class GatherND; class GatherND;
class ScatterAdd; class ScatterAdd;
class ScatterNDAdd; class ScatterNDAdd;
class Power;
class UpdateSlice; class UpdateSlice;
class ReplaceSlice; class ReplaceSlice;
class OneHot; class OneHot;
......
This diff is collapsed.
...@@ -69,13 +69,13 @@ set(SRC ...@@ -69,13 +69,13 @@ set(SRC
node_input_output.cpp node_input_output.cpp
nop_elimination.cpp nop_elimination.cpp
op.cpp op.cpp
opset_pass/binary_elementwise_opset_pass.cpp
opset_pass/broadcast_opset_pass.cpp opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp opset_pass/convolution_opset_pass.cpp
opset_pass/dyn_reshape_opset_pass.cpp opset_pass/dyn_reshape_opset_pass.cpp
opset_pass/logical_and_opset_pass.cpp opset_pass/logical_and_opset_pass.cpp
opset_pass/logical_not_opset_pass.cpp opset_pass/logical_not_opset_pass.cpp
opset_pass/logical_or_opset_pass.cpp opset_pass/logical_or_opset_pass.cpp
opset_pass/logical_less_equal_opset_pass.cpp
opset_pass/logical_xor_opset_pass.cpp opset_pass/logical_xor_opset_pass.cpp
opset_pass/gather_opset_pass.cpp opset_pass/gather_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp opset_pass/generate_mask_opset_pass.cpp
......
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
//------------------------------------------------------------------------------
//
// Helper Functions
//
//------------------------------------------------------------------------------
// Builds a two-input graph with broadcastable shapes {1, 3, 2} and {1, 2},
// runs Opset0Downgrade on a v1 node (whose constructor defaults to NUMPY
// auto-broadcast), and checks that the resulting v0 node keeps the NUMPY
// broadcast spec, the expected output element type, and the broadcast
// output shape.
//
// \tparam OpV0        node type expected after the downgrade pass
// \tparam OpV1        node type initially placed in the function
// \param output_type  element type expected on the v0 node's output
// \param input_type   element type of both parameters (default: f32)
// \param node_name    expected description() of the v0 node; when empty,
//                     the v1 node's description is expected instead
template <typename OpV0, typename OpV1>
void test_type_prop_opset0_downgrade_pass(const element::Type& output_type,
                                          const element::Type& input_type = element::f32,
                                          const string& node_name = "")
{
    auto A = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
    auto B = make_shared<op::Parameter>(input_type, Shape{1, 2});
    const op::AutoBroadcastSpec np_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
    auto v1_node = make_shared<OpV1>(A, B);
    auto result = make_shared<op::Result>(v1_node);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);
    auto v0_result = f->get_results().at(0);
    auto node = v0_result->input(0).get_source_output().get_node_shared_ptr();
    // Checked cast: if the pass failed to replace the node, fail the test
    // cleanly instead of dereferencing a mistyped pointer (UB).
    auto v0_node = dynamic_pointer_cast<OpV0>(node);
    ASSERT_NE(v0_node, nullptr);
    EXPECT_EQ(v0_node->description(), (node_name.empty() ? v1_node->description() : node_name));
    EXPECT_EQ(v0_node->get_version(), 0);
    EXPECT_EQ(v0_node->get_autob(), np_auto_b);
    EXPECT_EQ(v0_node->output(0).get_element_type(), output_type);
    EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
// Downgrade check for arithmetic binary ops: the output element type
// matches the (f32) input element type.
template <typename OpV0, typename OpV1>
void test_opset0_arithmetic_downgrade_pass()
{
    test_type_prop_opset0_downgrade_pass<OpV0, OpV1>(element::f32);
}
// Downgrade check for comparison binary ops: the output element type is
// boolean regardless of the (f32) input element type.
template <typename OpV0, typename OpV1>
void test_opset0_comparison_downgrade_pass()
{
    test_type_prop_opset0_downgrade_pass<OpV0, OpV1>(element::boolean);
}
// Builds a two-input graph with identical shapes {1, 3, 2}, runs
// Opset1Upgrade on a v0 node (which has no auto-broadcast by default), and
// checks that the resulting v1 node carries a NONE broadcast spec, the
// expected output element type, and the unchanged output shape.
//
// \tparam OpV0        node type initially placed in the function
// \tparam OpV1        node type expected after the upgrade pass
// \param output_type  element type expected on the v1 node's output
// \param input_type   element type of both parameters (default: f32)
// \param node_name    expected description() of the v1 node; when empty,
//                     the v0 node's description is expected instead
template <typename OpV0, typename OpV1>
void test_type_prop_opset1_upgrade_pass(const element::Type& output_type,
                                        const element::Type& input_type = element::f32,
                                        const string& node_name = "")
{
    auto A = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
    auto B = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
    const op::AutoBroadcastSpec none_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
    auto v0_node = make_shared<OpV0>(A, B);
    auto result = make_shared<op::Result>(v0_node);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);
    auto v1_result = f->get_results().at(0);
    auto node = v1_result->input(0).get_source_output().get_node_shared_ptr();
    // Checked cast: if the pass failed to replace the node, fail the test
    // cleanly instead of dereferencing a mistyped pointer (UB).
    auto v1_node = dynamic_pointer_cast<OpV1>(node);
    ASSERT_NE(v1_node, nullptr);
    EXPECT_EQ(v1_node->description(), (node_name.empty() ? v0_node->description() : node_name));
    EXPECT_EQ(v1_node->get_version(), 1);
    EXPECT_EQ(v1_node->get_autob(), none_auto_b);
    EXPECT_EQ(v1_node->output(0).get_element_type(), output_type);
    EXPECT_EQ(v1_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
// Upgrade check for arithmetic binary ops: the output element type matches
// the (f32) input element type.
template <typename OpV0, typename OpV1>
void test_opset1_arithmetic_upgrade_pass()
{
    test_type_prop_opset1_upgrade_pass<OpV0, OpV1>(element::f32);
}
// Upgrade check for comparison binary ops: the output element type is
// boolean regardless of the (f32) input element type.
template <typename OpV0, typename OpV1>
void test_opset1_comparison_upgrade_pass()
{
    test_type_prop_opset1_upgrade_pass<OpV0, OpV1>(element::boolean);
}
//------------------------------------------------------------------------------
//
// Test Cases
//
//------------------------------------------------------------------------------
// Add (arithmetic): v1 -> v0 downgrade and v0 -> v1 upgrade round trips.
TEST(opset_transform, opset0_add_downgrade_pass)
{
    test_opset0_arithmetic_downgrade_pass<op::v0::Add, op::v1::Add>();
}
TEST(opset_transform, opset1_add_upgrade_pass)
{
    test_opset1_arithmetic_upgrade_pass<op::v0::Add, op::v1::Add>();
}
// Divide has its own test (rather than the shared helper) because the
// pythondiv flag must survive the downgrade alongside the NUMPY broadcast
// spec, output type, and broadcast output shape.
TEST(opset_transform, opset0_divide_downgrade_pass)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
    auto B = make_shared<op::Parameter>(element::f32, Shape{1, 2});
    const op::AutoBroadcastSpec np_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
    const bool pydiv = false;
    auto divide_v1 = make_shared<op::v1::Divide>(A, B);
    divide_v1->set_is_pythondiv(pydiv);
    auto result = make_shared<op::Result>(divide_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);
    auto divide_v0_result = f->get_results().at(0);
    auto node = divide_v0_result->input(0).get_source_output().get_node_shared_ptr();
    // Checked cast: fail cleanly if the pass did not replace the node,
    // rather than dereferencing a mistyped pointer (UB).
    auto divide_v0_node = dynamic_pointer_cast<op::v0::Divide>(node);
    ASSERT_NE(divide_v0_node, nullptr);
    EXPECT_EQ(divide_v0_node->description(), "Divide");
    EXPECT_EQ(divide_v0_node->get_version(), 0);
    EXPECT_EQ(divide_v0_node->is_pythondiv(), pydiv);
    EXPECT_EQ(divide_v0_node->get_autob(), np_auto_b);
    EXPECT_EQ(divide_v0_node->output(0).get_element_type(), element::f32);
    EXPECT_EQ(divide_v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
// Divide upgrade counterpart: the pythondiv flag set on the v0 node must
// survive the upgrade, and the v1 node must carry a NONE broadcast spec.
TEST(opset_transform, opset1_divide_upgrade_pass)
{
    auto A = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
    auto B = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
    const op::AutoBroadcastSpec none_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
    const bool pydiv = false;
    auto div_v0 = make_shared<op::v0::Divide>(A, B, pydiv);
    auto result = make_shared<op::Result>(div_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);
    auto divide_v1_result = f->get_results().at(0);
    auto node = divide_v1_result->input(0).get_source_output().get_node_shared_ptr();
    // Checked cast: fail cleanly if the pass did not replace the node,
    // rather than dereferencing a mistyped pointer (UB).
    auto divide_v1_node = dynamic_pointer_cast<op::v1::Divide>(node);
    ASSERT_NE(divide_v1_node, nullptr);
    EXPECT_EQ(divide_v1_node->description(), "Divide");
    EXPECT_EQ(divide_v1_node->get_version(), 1);
    EXPECT_EQ(divide_v1_node->is_pythondiv(), pydiv);
    EXPECT_EQ(divide_v1_node->get_autob(), none_auto_b);
    EXPECT_EQ(divide_v1_node->output(0).get_element_type(), element::f32);
    EXPECT_EQ(divide_v1_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
// Comparison ops (boolean output): Equal, Greater, GreaterEq, Less.
TEST(opset_transform, opset0_equal_downgrade_pass)
{
    test_opset0_comparison_downgrade_pass<op::v0::Equal, op::v1::Equal>();
}
TEST(opset_transform, opset1_equal_upgrade_pass)
{
    test_opset1_comparison_upgrade_pass<op::v0::Equal, op::v1::Equal>();
}
TEST(opset_transform, opset0_greater_downgrade_pass)
{
    test_opset0_comparison_downgrade_pass<op::v0::Greater, op::v1::Greater>();
}
TEST(opset_transform, opset1_greater_upgrade_pass)
{
    test_opset1_comparison_upgrade_pass<op::v0::Greater, op::v1::Greater>();
}
TEST(opset_transform, opset0_greater_eq_downgrade_pass)
{
    test_opset0_comparison_downgrade_pass<op::v0::GreaterEq, op::v1::GreaterEq>();
}
TEST(opset_transform, opset1_greater_eq_upgrade_pass)
{
    test_opset1_comparison_upgrade_pass<op::v0::GreaterEq, op::v1::GreaterEq>();
}
TEST(opset_transform, opset0_less_downgrade_pass)
{
    test_opset0_comparison_downgrade_pass<op::v0::Less, op::v1::Less>();
}
TEST(opset_transform, opset1_less_upgrade_pass)
{
    test_opset1_comparison_upgrade_pass<op::v0::Less, op::v1::Less>();
}
// LessEq is the one op renamed across opsets (v0::LessEq -> v1::LessEqual),
// so the expected description() is passed explicitly for each direction.
TEST(opset_transform, opset0_less_eq_downgrade_pass)
{
    test_type_prop_opset0_downgrade_pass<op::v0::LessEq, op::v1::LessEqual>(
        element::boolean, element::f32, "LessEq");
}
TEST(opset_transform, opset1_less_eq_upgrade_pass)
{
    test_type_prop_opset1_upgrade_pass<op::v0::LessEq, op::v1::LessEqual>(
        element::boolean, element::f32, "LessEqual");
}
// Arithmetic ops (output type equals input type): Maximum, Minimum,
// Multiply, then NotEqual (comparison) and Power (arithmetic).
TEST(opset_transform, opset0_maximum_downgrade_pass)
{
    test_opset0_arithmetic_downgrade_pass<op::v0::Maximum, op::v1::Maximum>();
}
TEST(opset_transform, opset1_maximum_upgrade_pass)
{
    test_opset1_arithmetic_upgrade_pass<op::v0::Maximum, op::v1::Maximum>();
}
TEST(opset_transform, opset0_minimum_downgrade_pass)
{
    test_opset0_arithmetic_downgrade_pass<op::v0::Minimum, op::v1::Minimum>();
}
TEST(opset_transform, opset1_minimum_upgrade_pass)
{
    test_opset1_arithmetic_upgrade_pass<op::v0::Minimum, op::v1::Minimum>();
}
TEST(opset_transform, opset0_multiply_downgrade_pass)
{
    test_opset0_arithmetic_downgrade_pass<op::v0::Multiply, op::v1::Multiply>();
}
TEST(opset_transform, opset1_multiply_upgrade_pass)
{
    test_opset1_arithmetic_upgrade_pass<op::v0::Multiply, op::v1::Multiply>();
}
TEST(opset_transform, opset0_not_equal_downgrade_pass)
{
    test_opset0_comparison_downgrade_pass<op::v0::NotEqual, op::v1::NotEqual>();
}
TEST(opset_transform, opset1_not_equal_upgrade_pass)
{
    test_opset1_comparison_upgrade_pass<op::v0::NotEqual, op::v1::NotEqual>();
}
TEST(opset_transform, opset0_power_downgrade_pass)
{
    test_opset0_arithmetic_downgrade_pass<op::v0::Power, op::v1::Power>();
}
TEST(opset_transform, opset1_power_upgrade_pass)
{
    test_opset1_arithmetic_upgrade_pass<op::v0::Power, op::v1::Power>();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
// Verifies that Opset1Upgrade rewrites v0::LessEq into v1::LessEqual and
// that the replacement node still produces a boolean output.
TEST(opset_transform, opset1_logical_less_equal_upgrade_pass)
{
    const Shape data_shape{5, 10, 15};
    const auto lhs = make_shared<op::Parameter>(element::boolean, data_shape);
    const auto rhs = make_shared<op::Parameter>(element::boolean, data_shape);
    const auto compare = make_shared<op::v0::LessEq>(lhs, rhs);
    auto function = make_shared<Function>(ResultVector{make_shared<op::Result>(compare)},
                                          ParameterVector{lhs, rhs});
    ngraph::pass::Manager manager;
    manager.register_pass<pass::Opset1Upgrade>();
    manager.run_passes(function);
    const auto replacement =
        function->get_result()->input(0).get_source_output().get_node_shared_ptr();
    const auto upgraded = static_pointer_cast<op::v1::LessEqual>(replacement);
    EXPECT_EQ(upgraded->description(), "LessEqual");
    EXPECT_EQ(upgraded->get_version(), 1);
    EXPECT_EQ(upgraded->output(0).get_element_type(), element::boolean);
}
// Verifies that Opset0Downgrade rewrites v1::LessEqual back into v0::LessEq
// and that the replacement node still produces a boolean output.
// NOTE(review): the test name says "opset1" but it exercises the
// Opset0Downgrade pass; the sibling binary-elementwise suite names its
// downgrade tests "opset0_*" — consider renaming for consistency.
TEST(opset_transform, opset1_logical_less_equal_downgrade_pass)
{
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto less_eq_v1 = make_shared<op::v1::LessEqual>(a, b);
    const auto result = make_shared<op::Result>(less_eq_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});
    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);
    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    const auto less_eq_v0 = static_pointer_cast<op::v0::LessEq>(pass_replacement_node);
    EXPECT_EQ(less_eq_v0->description(), "LessEq");
    EXPECT_EQ(less_eq_v0->get_version(), 0);
    const auto values_out_element_type = less_eq_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment