Commit f5b322cf authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[SPEC] Numpy AutoBroadcast as default for specific ops. (#3816)

* Add upgrade/downgrade pass for Add op.

* Add upgrade/downgrade pass for Divide op.

- Change default value of autobradcasting in v1 into NUMPY.

* Add v1 version for Equal operator.

* Rename helper functions to fix compiler errros.

Fix Divide op version.

* Add downgrade and upgrade passes for Equal op.

* Reformat test cases. Add helper functions.

Add UT for Equal op.

* Add upgrade/downgrade pass and UT for Greater op.

* Add upgrade/downgrade pass and UT for GreaterEq op.

* Add upgrade/downgrade pass and UT for Less op.

* Add upgrade/downgrade pass and UT for LessEq op.

* Add upgrade/downgrade pass and UT for Maximum op.

* Add upgrade/downgrade pass and UT for Minimum op.

* Add upgrade/downgrade pass and UT for Multiply op.

* Add upgrade/downgrade pass and UT for NotEqual op.

* Add upgrade/downgrade pass and UT for Power op.

* Force ops version 1.

* Don't inline templates.

* Fix namespaces and some formatting.

* Update ONNX Importer to produce v1 nGraph nodes.

* Fix function return type.

* Fix uninitialized local variable warning.

* Fix confilicting declarations.

* Apply clang-format.

* Fix errors for distributed nGraph with unavailable classes.

* Fix downgrade pass for LessEqual op.
parent 1039ea1a
......@@ -47,10 +47,8 @@ namespace ngraph
{
inline NodeVector add(const Node& node)
{
return {std::make_shared<ngraph::op::Add>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Add>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_7
......
......@@ -47,10 +47,8 @@ namespace ngraph
{
inline NodeVector div(const Node& node)
{
return {std::make_shared<ngraph::op::Divide>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Divide>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -31,10 +31,8 @@ namespace ngraph
{
inline NodeVector equal(const Node& node)
{
return {std::make_shared<ngraph::op::Equal>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Equal>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -31,10 +31,8 @@ namespace ngraph
{
inline NodeVector greater(const Node& node)
{
return {std::make_shared<ngraph::op::Greater>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Greater>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -46,9 +46,9 @@ namespace ngraph
<< " alpha value should be in range (0,1)";
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), Shape{}, std::vector<double>{alpha});
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_shape(),
std::vector<double>{alpha});
return {std::make_shared<ngraph::op::Maximum>(data * alpha_node, data)};
}
......
......@@ -31,10 +31,8 @@ namespace ngraph
{
inline NodeVector less(const Node& node)
{
return {std::make_shared<ngraph::op::Less>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Less>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -40,7 +40,7 @@ namespace ngraph
{
inline NodeVector max(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Maximum>(node);
return variadic::make_ng_variadic_op<ngraph::op::v1::Maximum>(node);
}
} // namespace set_8
......
......@@ -49,8 +49,7 @@ namespace ngraph
{
NodeVector mean(const Node& node)
{
auto sum =
variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node).front();
auto sum = variadic::make_ng_variadic_op<ngraph::op::v1::Add>(node).front();
auto shape = sum->get_shape();
// Create a Constant representing the number of inputs with the same shape as
......
......@@ -40,7 +40,7 @@ namespace ngraph
{
inline NodeVector min(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Minimum>(node);
return variadic::make_ng_variadic_op<ngraph::op::v1::Minimum>(node);
}
} // namespace set_8
......
......@@ -49,10 +49,8 @@ namespace ngraph
{
inline NodeVector mul(const Node& node)
{
return {std::make_shared<ngraph::op::Multiply>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Multiply>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_7
......
......@@ -31,10 +31,8 @@ namespace ngraph
{
inline NodeVector pow(const Node& node)
{
return {std::make_shared<ngraph::op::Power>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::Power>(node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -70,7 +70,7 @@ namespace ngraph
std::shared_ptr<ngraph::Node> values_below_neg_lambd =
std::make_shared<ngraph::op::Less>(input, negative_lambd);
std::shared_ptr<ngraph::Node> values_above_pos_lambd =
std::make_shared<ngraph::op::Greater>(input, positive_lambd);
std::make_shared<ngraph::op::v1::Greater>(input, positive_lambd);
// Convert from bool to the input type to be able to multiply adjusted inputs
// by the created masks
......
......@@ -16,9 +16,9 @@
#include <memory>
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/log.hpp"
......@@ -40,9 +40,10 @@ namespace ngraph
auto data = node.get_ng_inputs().at(0);
std::shared_ptr<ngraph::Node> zero_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 0.f);
std::shared_ptr<ngraph::Node> one_node =
builder::make_constant(data->get_element_type(), data->get_shape(), 1.f);
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{0.f});
std::shared_ptr<ngraph::Node> one_node = std::make_shared<ngraph::op::Constant>(
data->get_element_type(), data->get_shape(), std::vector<float>{1.f});
std::shared_ptr<ngraph::Node> positive_val_node =
data + std::make_shared<ngraph::op::Log>(
......
......@@ -40,7 +40,7 @@ namespace ngraph
{
inline NodeVector sum(const Node& node)
{
return variadic::make_ng_variadic_op_with_broadcast<ngraph::op::Add>(node);
return variadic::make_ng_variadic_op<ngraph::op::v1::Add>(node);
}
} // namespace set_8
......
......@@ -38,9 +38,9 @@ namespace ngraph
double alpha = node.get_attribute_value<double>("alpha", 1.0);
std::shared_ptr<ngraph::Node> alpha_node =
std::make_shared<ngraph::op::Constant>(
data->get_element_type(), ngraph::Shape{}, std::vector<double>{alpha});
alpha_node = ngraph::op::make_broadcast_node(alpha_node, data->get_shape());
std::make_shared<ngraph::op::Constant>(data->get_element_type(),
data->get_shape(),
std::vector<double>{alpha});
auto data_map = std::make_shared<ngraph::op::Convert>(
std::make_shared<ngraph::op::Greater>(data, alpha_node),
......
......@@ -60,36 +60,6 @@ namespace ngraph
return {result};
}
/// \brief Create an nGraph version of an ONNX variadic operation.
/// This creates a subgraph with a series of binary operations.
///
/// \param node Incoming ONNX opearation.
///
/// \tparam T Class of an nGraph binary operation (e.g. Add, Minimum, Maximum)
///
/// \return nGraph node equivalent of the ONNX operation
template <class T>
inline NodeVector make_ng_variadic_op_with_broadcast(const Node& node)
{
NodeVector ng_inputs{node.get_ng_inputs()};
// Templated binary operation - Creates Add, Minimum, Maximum, etc.
auto binary_operation = [](const std::shared_ptr<ngraph::Node>& arg0,
const std::shared_ptr<ngraph::Node>& arg1) {
NodeVector args{ngraph::op::numpy_style_broadcast({arg0, arg1})};
return std::make_shared<T>(args.at(0), args.at(1));
};
// Create a result node as a series of binary operations
auto result = std::accumulate(
std::next(std::begin(ng_inputs)), // First operand value - the second input
std::end(ng_inputs), // Last value - final input
ng_inputs.front(), // Initial value - first input
binary_operation);
return {result};
}
} // namespace variadic
} // namespace onnx_import
......
......@@ -19,23 +19,25 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Add::type_info;
// ------------------------------- v0 ------------------------------------------
op::Add::Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Add::type_info;
op::v0::Add::Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Add::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Add::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Add>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v0::Add>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -55,3 +57,37 @@ shared_ptr<Node> ngraph::operator+(const Output<Node>& arg0, const Output<Node>&
{
return make_shared<op::Add>(arg0, arg1);
}
// ------------------------------- v1 ------------------------------------------
constexpr NodeTypeInfo op::v1::Add::type_info;
op::v1::Add::Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Add::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Add>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v1::Add::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta);
adjoints.add_delta(y, delta);
}
......@@ -24,39 +24,83 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise addition operation.
///
class Add : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
/// \brief Elementwise addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
class Add : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise addition operation.
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
class Add : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Add", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an uninitialized addition operation
Add() = default;
/// \brief Constructs an addition operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification. Default is Numpy-style
/// implicit broadcasting.
///
/// Output `[d0, ...]`
///
Add(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
}
using v0::Add;
} // namespace op
std::shared_ptr<Node> operator+(const Output<Node>& arg0, const Output<Node>& arg1);
}
} // namespace ngraph
......@@ -21,34 +21,36 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Divide::type_info;
// ------------------------------ v0 -------------------------------------------
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Divide::type_info;
op::v0::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
op::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast)
op::v0::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
, m_pythondiv(pythondiv)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Divide::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v0::Divide::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Divide>(
return make_shared<op::v0::Divide>(
new_args.at(0), new_args.at(1), this->is_pythondiv(), this->get_autob());
}
void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v0::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -66,5 +68,50 @@ void op::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVecto
shared_ptr<Node> ngraph::operator/(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Divide>(arg0, arg1);
return make_shared<op::v0::Divide>(arg0, arg1);
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Divide::type_info;
op::v1::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
op::v1::Divide::Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
, m_pythondiv(pythondiv)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::Divide::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::Divide>(
new_args.at(0), new_args.at(1), this->is_pythondiv(), this->get_autob());
}
void op::v1::Divide::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta / y);
adjoints.add_delta(y, -delta * shared_from_this() / y);
}
......@@ -22,47 +22,98 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise division operation.
class Divide : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Divide", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Elementwise division operation.
class Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Divide", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
protected:
bool m_pythondiv{true};
};
}
protected:
bool m_pythondiv{true};
};
} // namespace v0
namespace v1
{
/// \brief Elementwise division operation.
class Divide : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Divide", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a division operation.
Divide() = default;
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param pythondiv Use Python style rounding for integral type
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
bool pythondiv,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
/// \brief Constructs a division operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Divide(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool is_pythondiv() const { return m_pythondiv; }
void set_is_pythondiv(bool pythondiv) { m_pythondiv = pythondiv; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
size_t get_version() const override { return 1; }
protected:
bool m_pythondiv{true};
};
} // namespace v1
using v0::Divide;
} // namespace op
std::shared_ptr<Node> operator/(const Output<Node>& arg0, const Output<Node>& arg1);
}
} // namespace ngraph
......@@ -19,18 +19,38 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Equal::type_info;
//------------------------------- v0 -------------------------------------------
op::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Equal::type_info;
op::v0::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Equal::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Equal>(new_args.at(0), new_args.at(1), this->get_autob());
}
//------------------------------- v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Equal::type_info;
op::v1::Equal::Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Equal::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Equal::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Equal>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Equal>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,44 +22,94 @@ namespace ngraph
{
namespace op
{
// clang-format off
/// \brief Elementwise is-equal operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Equal", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
// clang-format off
/// \brief Elementwise is-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Equal", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
namespace v1
{
// clang-format off
/// \brief Elementwise is-equal operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and element type. |
/// | `arg1` | \f$E[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
/// | `autob`| AutoBroadcastSpec | Auto broadcast specification. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
/// | \f$\texttt{bool}[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = 1\text{ if }\texttt{arg0}[i_1,\dots,i_n] = \texttt{arg1}[i_1,\dots,i_n]\text{, else } 0\f$ |
// clang-format on
class Equal : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Equal", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs an equal operation.
Equal() = default;
/// \brief Constructs an equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Equal(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
};
} // namespace v1
virtual bool is_commutative() const override { return true; }
};
using v0::Equal;
}
}
......@@ -19,18 +19,38 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Greater::type_info;
//-------------------------------------- v0 ------------------------------------
op::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Greater::type_info;
op::v0::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Greater::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
}
//-------------------------------------- v1 ------------------------------------
constexpr NodeTypeInfo op::v1::Greater::type_info;
op::v1::Greater::Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Greater::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Greater::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Greater>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Greater>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,26 +22,58 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise greater-than operation.
class Greater : public util::BinaryElementwiseComparison
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Greater", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Elementwise greater-than operation.
class Greater : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Greater", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise greater-than operation.
class Greater : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Greater", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than operation.
Greater() = default;
/// \brief Constructs a greater-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Greater(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
};
} // namespace v1
using v0::Greater;
}
}
......@@ -19,18 +19,38 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::GreaterEq::type_info;
//---------------------------------- v0 ----------------------------------------
op::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::GreaterEq::type_info;
op::v0::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob());
}
//---------------------------------- v1 ----------------------------------------
constexpr NodeTypeInfo op::v1::GreaterEq::type_info;
op::v1::GreaterEq::GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::GreaterEq::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::GreaterEq>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,26 +22,58 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise greater-than-or-equal operation.
class GreaterEq : public util::BinaryElementwiseComparison
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"GreaterEq", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Elementwise greater-than-or-equal operation.
class GreaterEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"GreaterEq", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise greater-than-or-equal operation.
class GreaterEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"GreaterEq", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a greater-than-or-equal operation.
GreaterEq() = default;
/// \brief Constructs a greater-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
GreaterEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
};
} // namespace v1
using v0::GreaterEq;
}
}
......@@ -19,18 +19,38 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Less::type_info;
// ----------------------------- v0 --------------------------------------------
op::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Less::type_info;
op::v0::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Less::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Less>(new_args.at(0), new_args.at(1), this->get_autob());
}
// ----------------------------- v1 --------------------------------------------
constexpr NodeTypeInfo op::v1::Less::type_info;
op::v1::Less::Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Less::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Less::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Less>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Less>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,26 +22,58 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise less-than operation.
class Less : public util::BinaryElementwiseComparison
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Less", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Elementwise less-than operation.
class Less : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Less", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise less-than operation.
class Less : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Less", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than operation.
Less() = default;
/// \brief Constructs a less-than operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Less(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
};
} // namespace v1
using v0::Less;
}
}
......@@ -19,6 +19,8 @@
using namespace std;
using namespace ngraph;
// ---------------------------------- v1 ---------------------------------------
constexpr NodeTypeInfo op::v1::LessEqual::type_info;
op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
......@@ -35,6 +37,8 @@ shared_ptr<Node> op::v1::LessEqual::copy_with_new_args(const NodeVector& new_arg
return make_shared<v1::LessEqual>(new_args.at(0), new_args.at(1), this->get_autob());
}
// ---------------------------------- v0 ---------------------------------------
constexpr NodeTypeInfo op::v0::LessEq::type_info;
op::v0::LessEq::LessEq(const Output<Node>& arg0,
......
......@@ -33,6 +33,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation.
LessEqual() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
......@@ -40,12 +41,14 @@ namespace ngraph
/// \param auto_broadcast Auto broadcast specification
LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v1
namespace v0
{
/// \brief Elementwise less-than-or-equal operation.
......@@ -57,6 +60,7 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
......
......@@ -25,23 +25,62 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Maximum::type_info;
// ------------------------------------ v0 -------------------------------------
op::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Maximum::type_info;
op::v0::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Maximum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Maximum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(
x,
delta * make_shared<op::Convert>(make_shared<op::v0::Greater>(x, y), x.get_element_type()));
adjoints.add_delta(
y,
delta * make_shared<op::Convert>(make_shared<op::v0::Greater>(y, x), y.get_element_type()));
}
// ------------------------------------ v1 -------------------------------------
constexpr NodeTypeInfo op::v1::Maximum::type_info;
op::v1::Maximum::Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Maximum::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Maximum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Maximum>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Maximum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v1::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -53,7 +92,9 @@ void op::Maximum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Greater>(x, y), x.get_element_type()));
x,
delta * make_shared<op::Convert>(make_shared<op::v1::Greater>(x, y), x.get_element_type()));
adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Greater>(y, x), y.get_element_type()));
y,
delta * make_shared<op::Convert>(make_shared<op::v1::Greater>(y, x), y.get_element_type()));
}
......@@ -22,31 +22,68 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise maximum operation.
class Maximum : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Maximum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
/// \brief Elementwise maximum operation.
class Maximum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Maximum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise maximum operation.
class Maximum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Maximum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a maximum operation.
Maximum() = default;
/// \brief Constructs a maximum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Maximum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Maximum;
}
}
......@@ -25,23 +25,61 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Minimum::type_info;
// ------------------------------ v0 -------------------------------------------
op::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Minimum::type_info;
op::v0::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Minimum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Minimum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::v0::Less>(x, y), x.get_element_type()));
adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::v0::Less>(y, x), y.get_element_type()));
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Minimum::type_info;
op::v1::Minimum::Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Minimum::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Minimum::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Minimum>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Minimum>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v1::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -54,7 +92,7 @@ void op::Minimum::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVect
auto y = input_value(1);
adjoints.add_delta(
x, delta * make_shared<op::Convert>(make_shared<op::Less>(x, y), x.get_element_type()));
x, delta * make_shared<op::Convert>(make_shared<op::v1::Less>(x, y), x.get_element_type()));
adjoints.add_delta(
y, delta * make_shared<op::Convert>(make_shared<op::Less>(y, x), y.get_element_type()));
y, delta * make_shared<op::Convert>(make_shared<op::v1::Less>(y, x), y.get_element_type()));
}
......@@ -22,31 +22,68 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise minimum operation.
class Minimum : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Minimum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
/// \brief Elementwise minimum operation.
class Minimum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Minimum", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise minimum operation.
class Minimum : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Minimum", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a minimum operation.
Minimum() = default;
/// \brief Constructs a minimum operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Minimum(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Minimum;
}
}
......@@ -19,23 +19,59 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Multiply::type_info;
// ------------------------------------ v0 -------------------------------------
op::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Multiply::type_info;
op::v0::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Multiply::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
adjoints.add_delta(x, delta * y);
adjoints.add_delta(y, x * delta);
}
// ------------------------------------ v1 -------------------------------------
constexpr NodeTypeInfo op::v1::Multiply::type_info;
op::v1::Multiply::Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Multiply::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Multiply::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Multiply>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v1::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......@@ -51,6 +87,8 @@ void op::Multiply::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVec
adjoints.add_delta(y, x * delta);
}
// -----------------------------------------------------------------------------
shared_ptr<Node> ngraph::operator*(const Output<Node>& arg0, const Output<Node>& arg1)
{
return make_shared<op::Multiply>(arg0, arg1);
......
......@@ -22,33 +22,70 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise multiplication operation.
class Multiply : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Multiply", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
};
/// \brief Elementwise multiplication operation.
class Multiply : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Multiply", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
/// \brief Elementwise multiplication operation.
class Multiply : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Multiply", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a multiplication operation.
Multiply() = default;
/// \brief Constructs a multiplication operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Multiply(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
using v0::Multiply;
} // namespace op
std::shared_ptr<Node> operator*(const Output<Node>& arg0, const Output<Node>& arg1);
}
} // namespace ngraph
......@@ -19,18 +19,38 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::NotEqual::type_info;
// ----------------------------------- v0 --------------------------------------
op::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::NotEqual::type_info;
op::v0::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
}
// ----------------------------------- v1 --------------------------------------
constexpr NodeTypeInfo op::v1::NotEqual::type_info;
op::v1::NotEqual::NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::NotEqual::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::NotEqual::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::NotEqual>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,28 +22,62 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise not-equal operation.
class NotEqual : public util::BinaryElementwiseComparison
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"NotEqual", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
/// \brief Elementwise not-equal operation.
class NotEqual : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"NotEqual", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
namespace v1
{
/// \brief Elementwise not-equal operation.
class NotEqual : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"NotEqual", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a not-equal operation.
NotEqual() = default;
/// \brief Constructs a not-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
NotEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
};
} // namespace v1
using v0::NotEqual;
}
}
......@@ -22,23 +22,61 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Power::type_info;
// ------------------------------ v0 -------------------------------------------
op::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
constexpr NodeTypeInfo op::v0::Power::type_info;
op::v0::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v0::Power::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v0::Power>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::v0::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
throw ngraph_error("Autodiff not supported with auto broadcasting");
}
auto delta = deltas.at(0);
auto x = input_value(0);
auto y = input_value(1);
auto log_x = make_shared<op::Log>(x);
adjoints.add_delta(x, delta * y * shared_from_this() / x);
adjoints.add_delta(y, delta * shared_from_this() * log_x);
}
// ------------------------------ v1 -------------------------------------------
constexpr NodeTypeInfo op::v1::Power::type_info;
op::v1::Power::Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseArithmetic(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Power::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::Power::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Power>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<op::v1::Power>(new_args.at(0), new_args.at(1), this->get_autob());
}
void op::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
void op::v1::Power::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
{
if (get_autob().m_type != op::AutoBroadcastType::NONE)
{
......
......@@ -22,44 +22,93 @@ namespace ngraph
{
namespace op
{
// clang-format off
/// \brief Elementwise exponentiation operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
namespace v0
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
// clang-format off
/// \brief Elementwise exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v0
namespace v1
{
// clang-format off
/// \brief Elementwise exponentiation operation.
///
/// ## Inputs
///
/// | | Type | Description |
/// | ------ | --------------------------------- | ------------------------------------------------------ |
/// | `arg0` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of any shape and numeric element type. |
/// | `arg1` | \f$N[d_1,\dots,d_n]~(n \geq 0)\f$ | A tensor of the same shape and element type as `arg0`. |
///
/// ## Output
///
/// | Type | Description |
/// | ---------------------- | -------------------------------------------------------------------------------------------------------------- |
/// | \f$N[d_1,\dots,d_n]\f$ | The tensor \f$T\f$, where \f$T[i_1,\dots,i_n] = \texttt{arg0}[i_1,\dots,i_n]^{\texttt{arg1}[i_1,\dots,i_n]}\f$ |
// clang-format on
class Power : public util::BinaryElementwiseArithmetic
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Power", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Power() = default;
/// \brief Constructs an exponentiation operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
Power(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
size_t get_version() const override { return 1; }
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
} // namespace v1
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
using v0::Power;
}
}
......@@ -18,19 +18,30 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
......@@ -79,6 +90,17 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node)
}
// END mapping to OP_TYPEID
template <typename OpV0, typename OpV1>
void downgrade_binary_elementwise_node(const shared_ptr<Node>& node)
{
const auto tmp = as_type_ptr<OpV1>(node);
const auto input_arg0 = node->input(0).get_source_output();
const auto input_arg1 = node->input(1).get_source_output();
const auto autob = tmp->get_autob();
auto replacement_node = make_shared<OpV0>(input_arg0, input_arg1, autob);
replace_node(node, replacement_node);
}
bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
......@@ -104,6 +126,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Add:
{
downgrade_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
modified = true;
break;
}
case OP_TYPEID::AvgPool:
{
const auto tmp = as_type_ptr<op::v1::AvgPool>(node);
......@@ -250,6 +278,18 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Divide:
{
const auto tmp = as_type_ptr<op::v1::Divide>(node);
const auto input_arg0 = node->input(0).get_source_output();
const auto input_arg1 = node->input(1).get_source_output();
const auto autob = tmp->get_autob();
const bool pydiv = tmp->is_pythondiv();
auto replacement_node = make_shared<op::v0::Divide>(input_arg0, input_arg1, pydiv, autob);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::DynReshape:
{
auto tmp = as_type_ptr<op::v1::Reshape>(node);
......@@ -260,6 +300,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Equal:
{
downgrade_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
modified = true;
break;
}
case OP_TYPEID::GenerateMask:
{
auto tmp = dynamic_cast<const op::v1::GenerateMask*>(node.get());
......@@ -279,23 +325,33 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Greater:
{
downgrade_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
modified = true;
break;
}
case OP_TYPEID::GreaterEq:
{
downgrade_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEq>(node);
modified = true;
break;
}
case OP_TYPEID::Less:
{
downgrade_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
modified = true;
break;
}
case OP_TYPEID::LessEqual:
{
auto less_eq_v1 = as_type_ptr<op::v1::LessEqual>(node);
auto replacement_node = make_shared<op::v0::LessEq>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
less_eq_v1->get_autob());
replace_node(node, replacement_node);
downgrade_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
modified = true;
break;
}
case OP_TYPEID::LogicalAnd:
{
auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(node);
auto replacement_node = make_shared<op::v0::And>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v1->get_autob());
replace_node(node, replacement_node);
downgrade_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
modified = true;
break;
}
......@@ -307,21 +363,19 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
}
case OP_TYPEID::LogicalOr:
{
auto or_v1 = as_type_ptr<op::v1::LogicalOr>(node);
auto replacement_node = make_shared<op::v0::Or>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v1->get_autob());
replace_node(node, replacement_node);
downgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
modified = true;
break;
}
case OP_TYPEID::LogicalXor:
{
auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(node);
auto replacement_node = make_shared<op::v0::Xor>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
xor_v1->get_autob());
replace_node(node, replacement_node);
downgrade_binary_elementwise_node<op::v0::Xor, op::v1::LogicalXor>(node);
modified = true;
break;
}
case OP_TYPEID::Maximum:
{
downgrade_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
modified = true;
break;
}
......@@ -385,6 +439,24 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Minimum:
{
downgrade_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
modified = true;
break;
}
case OP_TYPEID::Multiply:
{
downgrade_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
modified = true;
break;
}
case OP_TYPEID::NotEqual:
{
downgrade_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = as_type_ptr<op::v1::Pad>(node);
......@@ -397,6 +469,12 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Power:
{
downgrade_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
modified = true;
break;
}
case OP_TYPEID::Product:
{
auto tmp = as_type_ptr<op::v1::ReduceProd>(node);
......
......@@ -15,19 +15,30 @@
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
#include "ngraph/op/reduce_sum.hpp"
......@@ -77,6 +88,16 @@ static OP_TYPEID get_typeid(shared_ptr<Node> node)
}
// END mapping to OP_TYPEID
template <typename OpV0, typename OpV1>
void upgrade_binary_elementwise_node(const shared_ptr<Node>& node)
{
const auto tmp = dynamic_cast<const OpV0*>(node.get());
const auto autob = tmp->get_autob();
auto replacement_node = make_shared<OpV1>(
node->input(0).get_source_output(), node->input(1).get_source_output(), autob);
replace_node(node, replacement_node);
}
bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
......@@ -102,13 +123,15 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Add:
{
upgrade_binary_elementwise_node<op::v0::Add, op::v1::Add>(node);
modified = true;
break;
}
case OP_TYPEID::And:
{
const auto and_v0 = dynamic_cast<const op::v0::And*>(node.get());
auto replacement_node = make_shared<op::v1::LogicalAnd>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v0->get_autob());
replace_node(node, replacement_node);
upgrade_binary_elementwise_node<op::v0::And, op::v1::LogicalAnd>(node);
modified = true;
break;
}
......@@ -284,6 +307,17 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Divide:
{
const auto tmp = dynamic_cast<const op::v0::Divide*>(node.get());
const auto autob = tmp->get_autob();
const bool pydiv = tmp->is_pythondiv();
auto replacement_node = make_shared<op::v1::Divide>(
node->input(0).get_source_output(), node->input(1).get_source_output(), pydiv, autob);
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::DynReshape:
{
auto zero_flag = false;
......@@ -293,6 +327,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Equal:
{
upgrade_binary_elementwise_node<op::v0::Equal, op::v1::Equal>(node);
modified = true;
break;
}
case OP_TYPEID::Gather:
{
auto tmp = dynamic_cast<const op::v0::Gather*>(node.get());
......@@ -305,13 +345,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Greater:
{
upgrade_binary_elementwise_node<op::v0::Greater, op::v1::Greater>(node);
modified = true;
break;
}
case OP_TYPEID::GreaterEq:
{
upgrade_binary_elementwise_node<op::v0::GreaterEq, op::v1::GreaterEq>(node);
modified = true;
break;
}
case OP_TYPEID::Less:
{
upgrade_binary_elementwise_node<op::v0::Less, op::v1::Less>(node);
modified = true;
break;
}
case OP_TYPEID::LessEq:
{
const auto less_eq_v0 = dynamic_cast<const op::v0::LessEq*>(node.get());
auto replacement_node = make_shared<op::v1::LessEqual>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
less_eq_v0->get_autob());
replace_node(node, replacement_node);
upgrade_binary_elementwise_node<op::v0::LessEq, op::v1::LessEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Maximum:
{
upgrade_binary_elementwise_node<op::v0::Maximum, op::v1::Maximum>(node);
modified = true;
break;
}
......@@ -372,19 +432,33 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Minimum:
{
upgrade_binary_elementwise_node<op::v0::Minimum, op::v1::Minimum>(node);
modified = true;
break;
}
case OP_TYPEID::Multiply:
{
upgrade_binary_elementwise_node<op::v0::Multiply, op::v1::Multiply>(node);
modified = true;
break;
}
case OP_TYPEID::Not:
{
replace_node(node, make_shared<op::v1::LogicalNot>(node->input(0).get_source_output()));
modified = true;
break;
}
case OP_TYPEID::NotEqual:
{
upgrade_binary_elementwise_node<op::v0::NotEqual, op::v1::NotEqual>(node);
modified = true;
break;
}
case OP_TYPEID::Or:
{
const auto or_v0 = dynamic_cast<const op::v0::Or*>(node.get());
auto replacement_node = make_shared<op::v1::LogicalOr>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v0->get_autob());
replace_node(node, replacement_node);
upgrade_binary_elementwise_node<op::v0::Or, op::v1::LogicalOr>(node);
modified = true;
break;
}
......@@ -408,6 +482,12 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Power:
{
upgrade_binary_elementwise_node<op::v0::Power, op::v1::Power>(node);
modified = true;
break;
}
case OP_TYPEID::Product:
{
bool keep_dims = false;
......
......@@ -21,19 +21,30 @@
#include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
......@@ -55,7 +66,6 @@ namespace ngraph
{
namespace op
{
class Add;
class AllReduce;
class BroadcastDistributed;
class MatmulBias;
......@@ -69,23 +79,14 @@ namespace ngraph
class BatchNormInferenceRelu;
class BatchNormTrainingBackprop;
class Dot;
class Multiply;
class GetOutputElement;
class Abs;
class Concat;
class Divide;
class Equal;
class Greater;
class GreaterEq;
class Less;
class Any;
class All;
class LRN;
class Log;
class Maximum;
class Minimum;
class Negative;
class NotEqual;
class Select;
class Subtract;
class Convert;
......@@ -107,7 +108,6 @@ namespace ngraph
class GatherND;
class ScatterAdd;
class ScatterNDAdd;
class Power;
class UpdateSlice;
class ReplaceSlice;
class OneHot;
......
......@@ -346,7 +346,10 @@ static json write_auto_broadcast(const op::AutoBroadcastSpec& autob)
return j;
}
static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string& attr)
static op::AutoBroadcastSpec
read_auto_broadcast(json js_node,
const std::string& attr,
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec())
{
if (has_key(js_node, attr))
{
......@@ -356,7 +359,7 @@ static op::AutoBroadcastSpec read_auto_broadcast(json js_node, const std::string
}
else
{
return op::AutoBroadcastSpec();
return autob;
}
}
......@@ -767,9 +770,20 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Add:
{
node = make_shared<op::Add>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
if (op_version == 0)
{
node = make_shared<op::v0::Add>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
if (op_version == 1)
{
node = make_shared<op::v1::Add>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
}
case OP_TYPEID::All:
{
......@@ -1229,8 +1243,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::Divide:
{
bool pythondiv = get_or_default(node_js, "pythondiv", true);
node = make_shared<op::Divide>(
args[0], args[1], pythondiv, read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Divide>(
args[0], args[1], pythondiv, read_auto_broadcast(node_js, "auto_broadcast"));
}
if (op_version == 1)
{
node = make_shared<op::v1::Divide>(
args[0],
args[1],
pythondiv,
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::Dot:
......@@ -1320,8 +1345,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Equal:
{
node = make_shared<op::Equal>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Equal>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
if (op_version == 1)
{
node = make_shared<op::v1::Equal>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::Erf:
......@@ -1414,14 +1449,34 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Greater:
{
node = make_shared<op::Greater>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Greater>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
if (op_version == 1)
{
node = make_shared<op::v1::Greater>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::GreaterEq:
{
node = make_shared<op::GreaterEq>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::GreaterEq>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
if (op_version == 1)
{
node = make_shared<op::v1::GreaterEq>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::GRN:
......@@ -1551,8 +1606,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Less:
{
node = make_shared<op::Less>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Less>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::Less>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::LessEq:
......@@ -1564,7 +1629,9 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
case OP_TYPEID::LessEqual:
{
node = make_shared<op::v1::LessEqual>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
break;
}
case OP_TYPEID::Log:
......@@ -1788,8 +1855,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Maximum:
{
node = make_shared<op::Maximum>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Maximum>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::Maximum>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::Min:
......@@ -1800,14 +1877,34 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Minimum:
{
node = make_shared<op::Minimum>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Minimum>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::Minimum>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::Multiply:
{
node = make_shared<op::Multiply>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Multiply>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::Multiply>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::MVN:
......@@ -1832,8 +1929,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::NotEqual:
{
node = make_shared<op::NotEqual>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::NotEqual>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::NotEqual>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::Not:
......@@ -1936,8 +2043,18 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Power:
{
node = make_shared<op::Power>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
if (op_version == 0)
{
node = make_shared<op::v0::Power>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
}
else if (op_version == 1)
{
node = make_shared<op::v1::Power>(
args[0],
args[1],
read_auto_broadcast(node_js, "auto_broadcast", op::AutoBroadcastType::NUMPY));
}
break;
}
case OP_TYPEID::PRelu:
......@@ -2575,8 +2692,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Add:
{
auto tmp = static_cast<const op::Add*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Add*>(&n);
}
if (op_version == 1)
{
tmp = static_cast<const op::v1::Add*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -2896,11 +3021,22 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Divide:
{
auto tmp = static_cast<const op::Divide*>(&n);
node["pythondiv"] = tmp->is_pythondiv();
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* bea_node = nullptr;
if (op_version == 0)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
auto tmp = static_cast<const op::v0::Divide*>(&n);
bea_node = tmp;
node["pythondiv"] = tmp->is_pythondiv();
}
else if (op_version == 1)
{
auto tmp = static_cast<const op::v1::Divide*>(&n);
bea_node = tmp;
node["pythondiv"] = tmp->is_pythondiv();
}
if (bea_node != nullptr && bea_node->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(bea_node->get_autob());
}
break;
}
......@@ -2958,8 +3094,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Equal:
{
auto tmp = static_cast<const op::Equal*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Equal*>(&n);
}
if (op_version == 1)
{
tmp = static_cast<const op::v1::Equal*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3023,8 +3167,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Greater:
{
auto tmp = static_cast<const op::Greater*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Greater*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Greater*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3032,8 +3184,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::GreaterEq:
{
auto tmp = static_cast<const op::GreaterEq*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::GreaterEq*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::GreaterEq*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3108,8 +3268,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Less:
{
auto tmp = static_cast<const op::Less*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Less*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Less*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3254,8 +3422,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Maximum:
{
auto tmp = static_cast<const op::Maximum*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Maximum*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Maximum*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3269,8 +3445,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Minimum:
{
auto tmp = static_cast<const op::Minimum*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Minimum*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Minimum*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3278,8 +3462,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Multiply:
{
auto tmp = static_cast<const op::Multiply*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Multiply*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Multiply*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3304,8 +3496,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::NotEqual:
{
auto tmp = static_cast<const op::NotEqual*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseComparison* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::NotEqual*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::NotEqual*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......@@ -3404,8 +3604,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Power:
{
auto tmp = static_cast<const op::Power*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
const op::util::BinaryElementwiseArithmetic* tmp = nullptr;
if (op_version == 0)
{
tmp = static_cast<const op::v0::Power*>(&n);
}
else if (op_version == 1)
{
tmp = static_cast<const op::v1::Power*>(&n);
}
if (tmp != nullptr && tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
......
......@@ -69,13 +69,13 @@ set(SRC
node_input_output.cpp
nop_elimination.cpp
op.cpp
opset_pass/binary_elementwise_opset_pass.cpp
opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp
opset_pass/dyn_reshape_opset_pass.cpp
opset_pass/logical_and_opset_pass.cpp
opset_pass/logical_not_opset_pass.cpp
opset_pass/logical_or_opset_pass.cpp
opset_pass/logical_less_equal_opset_pass.cpp
opset_pass/logical_xor_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
......
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/test_control.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
//------------------------------------------------------------------------------
//
// Helper Functions
//
//------------------------------------------------------------------------------
template <typename OpV0, typename OpV1>
void test_type_prop_opset0_downgrade_pass(const element::Type& output_type,
const element::Type& input_type = element::f32,
const string node_name = "")
{
auto A = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
auto B = make_shared<op::Parameter>(input_type, Shape{1, 2});
const op::AutoBroadcastSpec np_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
auto v1_node = make_shared<OpV1>(A, B);
auto result = make_shared<op::Result>(v1_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto v0_result = f->get_results().at(0);
auto node = v0_result->input(0).get_source_output().get_node_shared_ptr();
auto v0_node = static_pointer_cast<OpV0>(node);
EXPECT_EQ(v0_node->description(), (node_name.empty() ? v1_node->description() : node_name));
EXPECT_EQ(v0_node->get_version(), 0);
EXPECT_EQ(v0_node->get_autob(), np_auto_b);
EXPECT_EQ(v0_node->output(0).get_element_type(), output_type);
EXPECT_EQ(v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
template <typename OpV0, typename OpV1>
void test_opset0_arithmetic_downgrade_pass()
{
test_type_prop_opset0_downgrade_pass<OpV0, OpV1>(element::f32);
}
template <typename OpV0, typename OpV1>
void test_opset0_comparison_downgrade_pass()
{
test_type_prop_opset0_downgrade_pass<OpV0, OpV1>(element::boolean);
}
template <typename OpV0, typename OpV1>
void test_type_prop_opset1_upgrade_pass(const element::Type& output_type,
const element::Type& input_type = element::f32,
const string node_name = "")
{
auto A = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
auto B = make_shared<op::Parameter>(input_type, Shape{1, 3, 2});
const op::AutoBroadcastSpec none_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
auto v0_node = make_shared<OpV0>(A, B);
auto result = make_shared<op::Result>(v0_node);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto v1_result = f->get_results().at(0);
auto node = v1_result->input(0).get_source_output().get_node_shared_ptr();
auto v1_node = static_pointer_cast<OpV1>(node);
EXPECT_EQ(v1_node->description(), (node_name.empty() ? v0_node->description() : node_name));
EXPECT_EQ(v1_node->get_version(), 1);
EXPECT_EQ(v1_node->get_autob(), none_auto_b);
EXPECT_EQ(v1_node->output(0).get_element_type(), output_type);
EXPECT_EQ(v1_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
template <typename OpV0, typename OpV1>
void test_opset1_arithmetic_upgrade_pass()
{
test_type_prop_opset1_upgrade_pass<OpV0, OpV1>(element::f32);
}
template <typename OpV0, typename OpV1>
void test_opset1_comparison_upgrade_pass()
{
test_type_prop_opset1_upgrade_pass<OpV0, OpV1>(element::boolean);
}
//------------------------------------------------------------------------------
//
// Test Cases
//
//------------------------------------------------------------------------------
TEST(opset_transform, opset0_add_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Add, op::v1::Add>();
}
TEST(opset_transform, opset1_add_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Add, op::v1::Add>();
}
TEST(opset_transform, opset0_divide_downgrade_pass)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{1, 2});
const op::AutoBroadcastSpec np_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NUMPY);
const bool pydiv = false;
auto divide_v1 = make_shared<op::v1::Divide>(A, B);
divide_v1->set_is_pythondiv(pydiv);
auto result = make_shared<op::Result>(divide_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
auto divide_v0_result = f->get_results().at(0);
auto node = divide_v0_result->input(0).get_source_output().get_node_shared_ptr();
auto divide_v0_node = static_pointer_cast<op::v0::Divide>(node);
EXPECT_EQ(divide_v0_node->description(), "Divide");
EXPECT_EQ(divide_v0_node->get_version(), 0);
EXPECT_EQ(divide_v0_node->is_pythondiv(), pydiv);
EXPECT_EQ(divide_v0_node->get_autob(), np_auto_b);
EXPECT_EQ(divide_v0_node->output(0).get_element_type(), element::f32);
EXPECT_EQ(divide_v0_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
TEST(opset_transform, opset1_divide_upgrade_pass)
{
auto A = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{1, 3, 2});
const op::AutoBroadcastSpec none_auto_b = op::AutoBroadcastSpec(op::AutoBroadcastType::NONE);
const bool pydiv = false;
auto div_v0 = make_shared<op::v0::Divide>(A, B, pydiv);
auto result = make_shared<op::Result>(div_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{A, B});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
auto divide_v1_result = f->get_results().at(0);
auto node = divide_v1_result->input(0).get_source_output().get_node_shared_ptr();
auto divide_v1_node = static_pointer_cast<op::v1::Divide>(node);
EXPECT_EQ(divide_v1_node->description(), "Divide");
EXPECT_EQ(divide_v1_node->get_version(), 1);
EXPECT_EQ(divide_v1_node->is_pythondiv(), pydiv);
EXPECT_EQ(divide_v1_node->get_autob(), none_auto_b);
EXPECT_EQ(divide_v1_node->output(0).get_element_type(), element::f32);
EXPECT_EQ(divide_v1_node->output(0).get_shape(), (Shape{1, 3, 2}));
}
TEST(opset_transform, opset0_equal_downgrade_pass)
{
test_opset0_comparison_downgrade_pass<op::v0::Equal, op::v1::Equal>();
}
TEST(opset_transform, opset1_equal_upgrade_pass)
{
test_opset1_comparison_upgrade_pass<op::v0::Equal, op::v1::Equal>();
}
TEST(opset_transform, opset0_greater_downgrade_pass)
{
test_opset0_comparison_downgrade_pass<op::v0::Greater, op::v1::Greater>();
}
TEST(opset_transform, opset1_greater_upgrade_pass)
{
test_opset1_comparison_upgrade_pass<op::v0::Greater, op::v1::Greater>();
}
TEST(opset_transform, opset0_greater_eq_downgrade_pass)
{
test_opset0_comparison_downgrade_pass<op::v0::GreaterEq, op::v1::GreaterEq>();
}
TEST(opset_transform, opset1_greater_eq_upgrade_pass)
{
test_opset1_comparison_upgrade_pass<op::v0::GreaterEq, op::v1::GreaterEq>();
}
TEST(opset_transform, opset0_less_downgrade_pass)
{
test_opset0_comparison_downgrade_pass<op::v0::Less, op::v1::Less>();
}
TEST(opset_transform, opset1_less_upgrade_pass)
{
test_opset1_comparison_upgrade_pass<op::v0::Less, op::v1::Less>();
}
TEST(opset_transform, opset0_less_eq_downgrade_pass)
{
test_type_prop_opset0_downgrade_pass<op::v0::LessEq, op::v1::LessEqual>(
element::boolean, element::f32, "LessEq");
}
TEST(opset_transform, opset1_less_eq_upgrade_pass)
{
test_type_prop_opset1_upgrade_pass<op::v0::LessEq, op::v1::LessEqual>(
element::boolean, element::f32, "LessEqual");
}
TEST(opset_transform, opset0_maximum_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Maximum, op::v1::Maximum>();
}
TEST(opset_transform, opset1_maximum_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Maximum, op::v1::Maximum>();
}
TEST(opset_transform, opset0_minimum_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Minimum, op::v1::Minimum>();
}
TEST(opset_transform, opset1_minimum_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Minimum, op::v1::Minimum>();
}
TEST(opset_transform, opset0_multiply_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Multiply, op::v1::Multiply>();
}
TEST(opset_transform, opset1_multiply_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Multiply, op::v1::Multiply>();
}
TEST(opset_transform, opset0_not_equal_downgrade_pass)
{
test_opset0_comparison_downgrade_pass<op::v0::NotEqual, op::v1::NotEqual>();
}
TEST(opset_transform, opset1_not_equal_upgrade_pass)
{
test_opset1_comparison_upgrade_pass<op::v0::NotEqual, op::v1::NotEqual>();
}
TEST(opset_transform, opset0_power_downgrade_pass)
{
test_opset0_arithmetic_downgrade_pass<op::v0::Power, op::v1::Power>();
}
TEST(opset_transform, opset1_power_upgrade_pass)
{
test_opset1_arithmetic_upgrade_pass<op::v0::Power, op::v1::Power>();
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_less_equal_upgrade_pass)
{
const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
const auto less_eq_v0 = make_shared<op::v0::LessEq>(a, b);
const auto result = make_shared<op::Result>(less_eq_v0);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset1Upgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto less_eq_v1 = static_pointer_cast<op::v1::LessEqual>(pass_replacement_node);
EXPECT_EQ(less_eq_v1->description(), "LessEqual");
EXPECT_EQ(less_eq_v1->get_version(), 1);
const auto values_out_element_type = less_eq_v1->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_less_equal_downgrade_pass)
{
const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
const auto less_eq_v1 = make_shared<op::v1::LessEqual>(a, b);
const auto result = make_shared<op::Result>(less_eq_v1);
auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<pass::Opset0Downgrade>();
pass_manager.run_passes(f);
const auto pass_replacement_node =
f->get_result()->input(0).get_source_output().get_node_shared_ptr();
const auto less_eq_v0 = static_pointer_cast<op::v0::LessEq>(pass_replacement_node);
EXPECT_EQ(less_eq_v0->description(), "LessEq");
EXPECT_EQ(less_eq_v0->get_version(), 0);
const auto values_out_element_type = less_eq_v0->output(0).get_element_type();
EXPECT_EQ(values_out_element_type, element::boolean);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment