Commit 18d12ad8 authored by Tomasz Socha, committed by Scott Cyphers

[SPEC] [MLIR] Adjust names of logical operators to the spec. (#3819)

* [SPEC] Rename operator Xor->LogicalXor

* Fix clang issue

* Fix bug in opset transformations. Add support for LogicalXor in backends

* Style fix

* Fix a bug in CPU emitter

* [SPEC] Rename operator Or->LogicalOr

* [SPEC] Rename operator And->LogicalAnd

* [SPEC] Rename operator Not->LogicalNot

* [SPEC] Rename operator LessEq->LessEqual
parent be16a2fd
......@@ -19,7 +19,6 @@
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/util/broadcasting.hpp"
namespace ngraph
{
......@@ -31,8 +30,8 @@ namespace ngraph
{
inline NodeVector logical_and(const Node& node)
{
NodeVector ng_inputs{ngraph::op::numpy_style_broadcast(node.get_ng_inputs())};
return {std::make_shared<ngraph::op::And>(ng_inputs.at(0), ng_inputs.at(1))};
return {std::make_shared<ngraph::op::v1::LogicalAnd>(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -30,7 +30,8 @@ namespace ngraph
{
inline NodeVector logical_not(const Node& node)
{
return {std::make_shared<ngraph::op::Not>(node.get_ng_inputs().at(0))};
return {
std::make_shared<ngraph::op::v1::LogicalNot>(node.get_ng_inputs().at(0))};
}
} // namespace set_1
......
......@@ -31,10 +31,8 @@ namespace ngraph
{
inline NodeVector logical_or(const Node& node)
{
return {std::make_shared<ngraph::op::Or>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
return {std::make_shared<ngraph::op::v1::LogicalOr>(
node.get_ng_inputs().at(0), node.get_ng_inputs().at(1))};
}
} // namespace set_1
......
......@@ -31,7 +31,7 @@ namespace ngraph
{
inline NodeVector logical_xor(const Node& node)
{
return {std::make_shared<ngraph::op::Xor>(
return {std::make_shared<ngraph::op::v1::LogicalXor>(
node.get_ng_inputs().at(0),
node.get_ng_inputs().at(1),
ngraph::op::AutoBroadcastSpec(ngraph::op::AutoBroadcastType::NUMPY))};
......
......@@ -450,7 +450,7 @@ namespace ngraph
NodeVector get_users(bool check_is_used = false) const;
/// \return Version of this node
virtual size_t get_version() const { return 0; }
virtual size_t get_version() const { return get_type_info().version; }
virtual std::shared_ptr<Node> get_default_value() const { return nullptr; }
/// Use instance ids for comparison instead of memory addresses to improve determinism
bool operator<(const Node& other) const { return m_instance_id < other.m_instance_id; }
......
......@@ -19,18 +19,34 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::And::type_info;
constexpr NodeTypeInfo op::v1::LogicalAnd::type_info;
op::And::And(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
op::v1::LogicalAnd::LogicalAnd(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::And::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LogicalAnd::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<And>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<v1::LogicalAnd>(new_args.at(0), new_args.at(1), this->get_autob());
}
// Out-of-line definition for the static type descriptor declared in the class.
constexpr NodeTypeInfo op::v0::And::type_info;

// Constructs a v0 elementwise logical-and node over two inputs; element-type
// and shape validation is performed by the BinaryElementwiseLogical base via
// constructor_validate_and_infer_types().
op::v0::And::And(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}

// Clone hook: rebuilds an equivalent v0::And from new input nodes, preserving
// this node's auto-broadcast specification.
shared_ptr<Node> op::v0::And::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::And>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -24,34 +24,72 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-and operation.
///
class And : public util::BinaryElementwiseLogical
namespace v1
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"And", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-and operation.
And() = default;
/// \brief Constructs a logical-and operation.
/// \brief Elementwise logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
class LogicalAnd : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LogicalAnd", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-and operation.
LogicalAnd() = default;
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
LogicalAnd(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
namespace v0
{
/// \brief Elementwise logical-and operation.
///
And(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
class And : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"And", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-and operation.
And() = default;
/// \brief Constructs a logical-and operation.
///
/// \param arg0 Output that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Output that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
And(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
}
virtual bool is_commutative() const override { return true; }
};
using v0::And;
}
}
......@@ -19,18 +19,34 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::LessEq::type_info;
constexpr NodeTypeInfo op::v1::LessEqual::type_info;
op::LessEq::LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
op::v1::LessEqual::LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::LessEq::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LessEqual::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<LessEq>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<v1::LessEqual>(new_args.at(0), new_args.at(1), this->get_autob());
}
// Out-of-line definition for the static type descriptor declared in the class.
constexpr NodeTypeInfo op::v0::LessEq::type_info;

// Constructs a v0 elementwise less-than-or-equal comparison node; validation
// runs in the BinaryElementwiseComparison base through
// constructor_validate_and_infer_types().
op::v0::LessEq::LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseComparison(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}

// Clone hook: rebuilds an equivalent v0::LessEq from new input nodes,
// preserving this node's auto-broadcast specification.
shared_ptr<Node> op::v0::LessEq::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::LessEq>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -22,26 +22,55 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise less-than-or-equal operation.
class LessEq : public util::BinaryElementwiseComparison
namespace v1
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LessEq", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
/// \brief Elementwise less-than-or-equal operation.
class LessEqual : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LessEqual", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation.
LessEqual() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
LessEqual(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v1
namespace v0
{
/// \brief Elementwise less-than-or-equal operation.
class LessEq : public util::BinaryElementwiseComparison
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LessEq", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a less-than-or-equal operation.
LessEq() = default;
/// \brief Constructs a less-than-or-equal operation.
///
/// \param arg0 Node that produces the first input tensor.
/// \param arg1 Node that produces the second input tensor.
/// \param auto_broadcast Auto broadcast specification
LessEq(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
} // namespace v0
using v0::LessEq;
} // namespace op
} // namespace ngraph
......@@ -20,16 +20,16 @@
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::Not::type_info;
constexpr NodeTypeInfo op::v1::LogicalNot::type_info;
op::Not::Not(const Output<Node>& arg)
op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
: Op({arg})
{
constructor_validate_and_infer_types();
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::Not::validate_and_infer_types()
void op::v1::LogicalNot::validate_and_infer_types()
{
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
......@@ -38,8 +38,32 @@ void op::Not::validate_and_infer_types()
set_output_type(0, args_et, args_pshape);
}
shared_ptr<Node> op::Not::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LogicalNot::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Not>(new_args.at(0));
return make_shared<v1::LogicalNot>(new_args.at(0));
}
// Out-of-line definition for the static type descriptor declared in the class.
constexpr NodeTypeInfo op::v0::Not::type_info;

// Constructs a v0 elementwise logical-negation node over a single input.
op::v0::Not::Not(const Output<Node>& arg)
: Op({arg})
{
constructor_validate_and_infer_types();
}

// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
// Infers the output from the (single) input: the output keeps the input's
// element type and partial shape as resolved by the elementwise helper.
void op::v0::Not::validate_and_infer_types()
{
auto args_et_pshape = validate_and_infer_elementwise_args();
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, args_et, args_pshape);
}

// Clone hook: rebuilds an equivalent v0::Not from a new input node.
shared_ptr<Node> op::v0::Not::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::Not>(new_args.at(0));
}
......@@ -22,24 +22,51 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise logical negation operation.
class Not : public Op
namespace v1
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Not", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical negation operation.
Not() = default;
/// \brief Constructs a logical negation operation.
///
/// \param arg Node that produces the input tensor.
Not(const Output<Node>& arg);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
}
/// \brief Elementwise logical negation operation.
class LogicalNot : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LogicalNot", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical negation operation.
LogicalNot() = default;
/// \brief Constructs a logical negation operation.
///
/// \param arg Node that produces the input tensor.
LogicalNot(const Output<Node>& arg);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
namespace v0
{
/// \brief Elementwise logical negation operation.
class Not : public Op
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Not", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical negation operation.
Not() = default;
/// \brief Constructs a logical negation operation.
///
/// \param arg Node that produces the input tensor.
Not(const Output<Node>& arg);
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
}
using v0::Not;
} // namespace op
} // namespace ngraph
......@@ -101,7 +101,12 @@ NGRAPH_OP(Greater, ngraph::op)
NGRAPH_OP(GreaterEq, ngraph::op)
NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(LessEqual, ngraph::op)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LogicalAnd, ngraph::op)
NGRAPH_OP(LogicalNot, ngraph::op)
NGRAPH_OP(LogicalOr, ngraph::op)
NGRAPH_OP(LogicalXor, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(Max, ngraph::op)
NGRAPH_OP(Maximum, ngraph::op)
......
......@@ -19,18 +19,34 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Or::type_info;
constexpr NodeTypeInfo op::v1::LogicalOr::type_info;
op::Or::Or(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
op::v1::LogicalOr::LogicalOr(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Or::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LogicalOr::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Or>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<v1::LogicalOr>(new_args.at(0), new_args.at(1), this->get_autob());
}
// Out-of-line definition for the static type descriptor declared in the class.
constexpr NodeTypeInfo op::v0::Or::type_info;

// Constructs a v0 elementwise logical-or node over two inputs; element-type
// and shape validation is performed by the BinaryElementwiseLogical base via
// constructor_validate_and_infer_types().
op::v0::Or::Or(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}

// Clone hook: rebuilds an equivalent v0::Or from new input nodes, preserving
// this node's auto-broadcast specification.
shared_ptr<Node> op::v0::Or::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::Or>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -24,32 +24,68 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-or operation.
///
class Or : public util::BinaryElementwiseLogical
namespace v1
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Or", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-or operation.
/// \brief Elementwise logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
class LogicalOr : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LogicalOr", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
LogicalOr(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v1
namespace v0
{
/// \brief Elementwise logical-or operation.
///
Or(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
class Or : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Or", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-or operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
Or(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
virtual bool is_commutative() const override { return true; }
};
}
}
using v0::Or;
} // namespace op
} // namespace ngraph
......@@ -19,18 +19,34 @@
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::Xor::type_info;
constexpr NodeTypeInfo op::v1::LogicalXor::type_info;
op::Xor::Xor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
op::v1::LogicalXor::LogicalXor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::Xor::copy_with_new_args(const NodeVector& new_args) const
shared_ptr<Node> op::v1::LogicalXor::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Xor>(new_args.at(0), new_args.at(1), this->get_autob());
return make_shared<v1::LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob());
}
// Out-of-line definition for the static type descriptor declared in the class.
constexpr NodeTypeInfo op::v0::Xor::type_info;

// Constructs a v0 elementwise logical-xor node over two inputs; element-type
// and shape validation is performed by the BinaryElementwiseLogical base via
// constructor_validate_and_infer_types().
op::v0::Xor::Xor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast)
: BinaryElementwiseLogical(arg0, arg1, auto_broadcast)
{
constructor_validate_and_infer_types();
}

// Clone hook: rebuilds an equivalent v0::Xor from new input nodes, preserving
// this node's auto-broadcast specification.
shared_ptr<Node> op::v0::Xor::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v0::Xor>(new_args.at(0), new_args.at(1), this->get_autob());
}
......@@ -24,32 +24,69 @@ namespace ngraph
{
namespace op
{
/// \brief Elementwise logical-xor operation.
///
class Xor : public util::BinaryElementwiseLogical
namespace v1
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Xor", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-xor operation.
/// \brief Elementwise logical-xor operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
class LogicalXor : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"LogicalXor", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-xor operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
LogicalXor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY));
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v1
namespace v0
{
/// \brief Elementwise logical-xor operation.
///
Xor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
class Xor : public util::BinaryElementwiseLogical
{
public:
NGRAPH_API
static constexpr NodeTypeInfo type_info{"Xor", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
/// \brief Constructs a logical-xor operation.
///
/// \param arg0 Node that produces the first input tensor.<br>
/// `[d0, ...]`
/// \param arg1 Node that produces the second input tensor.<br>
/// `[d0, ...]`
/// \param auto_broadcast Auto broadcast specification
///
/// Output `[d0, ...]`
///
Xor(const Output<Node>& arg0,
const Output<Node>& arg1,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastSpec());
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
};
} // namespace v0
virtual bool is_commutative() const override { return true; }
};
}
}
// default opset version
using v0::Xor;
} // namespace op
} // namespace ngraph
......@@ -85,6 +85,17 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
and_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto logical_xor_node = as_type_ptr<op::v1::LogicalXor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
b->get_data_ptr<char>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
logical_xor_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto or_node = as_type_ptr<op::Or>(binary))
{
vector<char> out_vec(shape_size(out_shape));
......@@ -96,7 +107,7 @@ static shared_ptr<op::Constant> fold_constant_binary_logical(shared_ptr<op::Cons
or_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto xor_node = as_type_ptr<op::Xor>(binary))
else if (auto xor_node = as_type_ptr<op::v0::Xor>(binary))
{
vector<char> out_vec(shape_size(out_shape));
runtime::reference::logical_xor<char>(a->get_data_ptr<char>(),
......
......@@ -18,6 +18,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
......@@ -25,7 +26,10 @@
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
......@@ -35,6 +39,7 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/slice_plan.hpp"
......@@ -266,6 +271,52 @@ bool pass::Opset0Downgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::LessEqual:
{
auto less_eq_v1 = as_type_ptr<op::v1::LessEqual>(node);
auto replacement_node = make_shared<op::v0::LessEq>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
less_eq_v1->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::LogicalAnd:
{
auto and_v1 = as_type_ptr<op::v1::LogicalAnd>(node);
auto replacement_node = make_shared<op::v0::And>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v1->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::LogicalNot:
{
replace_node(node, make_shared<op::v0::Not>(node->input(0).get_source_output()));
modified = true;
break;
}
case OP_TYPEID::LogicalOr:
{
auto or_v1 = as_type_ptr<op::v1::LogicalOr>(node);
auto replacement_node = make_shared<op::v0::Or>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v1->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::LogicalXor:
{
auto xor_v1 = as_type_ptr<op::v1::LogicalXor>(node);
auto replacement_node = make_shared<op::v0::Xor>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
xor_v1->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::MaxPool:
{
auto tmp = as_type_ptr<op::v1::MaxPool>(node);
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/pass/opset1_upgrade.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
......@@ -22,7 +23,10 @@
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reduce_prod.hpp"
......@@ -34,6 +38,7 @@
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include <limits>
#include <numeric>
......@@ -97,6 +102,16 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
#endif
switch (get_typeid(node))
{
case OP_TYPEID::And:
{
const auto and_v0 = dynamic_cast<const op::v0::And*>(node.get());
auto replacement_node = make_shared<op::v1::LogicalAnd>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
and_v0->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::AvgPool:
{
auto tmp = dynamic_cast<const op::v0::AvgPool*>(node.get());
......@@ -290,6 +305,16 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::LessEq:
{
const auto less_eq_v0 = dynamic_cast<const op::v0::LessEq*>(node.get());
auto replacement_node = make_shared<op::v1::LessEqual>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
less_eq_v0->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::MaxPool:
{
auto tmp = dynamic_cast<const op::v0::MaxPool*>(node.get());
......@@ -347,6 +372,22 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Not:
{
replace_node(node, make_shared<op::v1::LogicalNot>(node->input(0).get_source_output()));
modified = true;
break;
}
case OP_TYPEID::Or:
{
const auto or_v0 = dynamic_cast<const op::v0::Or*>(node.get());
auto replacement_node = make_shared<op::v1::LogicalOr>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
or_v0->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
case OP_TYPEID::Pad:
{
auto tmp = dynamic_cast<const op::v0::Pad*>(node.get());
......@@ -486,6 +527,16 @@ bool pass::Opset1Upgrade::run_on_node(shared_ptr<Node> node)
modified = true;
break;
}
case OP_TYPEID::Xor:
{
const auto xor_v0 = dynamic_cast<const op::v0::Xor*>(node.get());
auto replacement_node = make_shared<op::v1::LogicalXor>(node->input(0).get_source_output(),
node->input(1).get_source_output(),
xor_v0->get_autob());
replace_node(node, replacement_node);
modified = true;
break;
}
default: break;
}
......
......@@ -247,6 +247,29 @@ namespace ngraph
functors.emplace_back(functor);
}
// CPU builder specialization for the v1 LogicalXor op: registers a runtime
// functor that applies the elementwise logical-xor kernel over the whole
// flattened output (element_count elements).
template <>
void Builder::BUILDER_DECL(ngraph::op::v1::LogicalXor)
{
(void)node;
auto& functors = external_function->get_functors();
auto element_count = out[0].get_size();
// Buffer indices are resolved now; the actual pointers are fetched from the
// runtime context each time the functor executes.
auto arg0_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto arg1_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto out0_buffer_index = external_function->get_buffer_index(out[0].get_name());
auto functor =
[&, element_count, arg0_buffer_index, arg1_buffer_index, out0_buffer_index](
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
runtime::cpu::kernel::logical_xor(ctx->buffer_data[arg0_buffer_index],
ctx->buffer_data[arg1_buffer_index],
ctx->buffer_data[out0_buffer_index],
element_count,
ectx->arena);
};
functors.emplace_back(functor);
}
template <>
void Builder::BUILDER_DECL(ngraph::op::Xor)
{
......
......@@ -28,7 +28,6 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/all.hpp"
#include "ngraph/op/allreduce.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/any.hpp"
#include "ngraph/op/argmax.hpp"
#include "ngraph/op/argmin.hpp"
......@@ -70,7 +69,6 @@
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/lrn.hpp"
#include "ngraph/op/max.hpp"
......@@ -80,11 +78,9 @@
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/power.hpp"
......@@ -111,7 +107,6 @@
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/cpu/cpu_executor.hpp"
#include "ngraph/runtime/cpu/cpu_kernel_emitters.hpp"
#include "ngraph/runtime/cpu/cpu_op_annotations.hpp"
......@@ -4169,6 +4164,17 @@ namespace ngraph
<< " " << out[0].get_size() << ");\n";
}
// CPU emitter specialization for the v1 LogicalXor op: writes C++ source that
// invokes the reference logical_xor kernel on the two input buffers, producing
// out[0] over the full element count.
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::v1::LogicalXor)
{
(void)external_function;
(void)node;
writer << "reference::logical_xor(" << args[0].get_name() << ",\n"
<< "                       " << args[1].get_name() << ",\n"
<< "                       " << out[0].get_name() << ",\n"
<< "                       " << out[0].get_size() << ");\n";
}
template <>
void CPU_Emitter::EMITTER_DECL(ngraph::op::Xor)
{
......
......@@ -21,20 +21,25 @@
#include "ngraph/code_writer.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/experimental/generate_mask.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/less_eq.hpp"
#include "ngraph/op/max.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/cpu/cpu_external_function.hpp"
#include "ngraph/runtime/cpu/cpu_tensor_view_wrapper.hpp"
#include "ngraph/runtime/cpu/op/gelu_backprop.hpp"
......@@ -73,7 +78,6 @@ namespace ngraph
class Greater;
class GreaterEq;
class Less;
class LessEq;
class Any;
class All;
class LRN;
......@@ -126,7 +130,6 @@ namespace ngraph
class ConvolutionBiasAdd;
class ConvolutionAdd;
class ConvolutionBiasBackpropFiltersBias;
class Not;
class QuantizedMaxPool;
class QuantizedAvgPool;
class MaxPoolWithIndices;
......@@ -142,9 +145,6 @@ namespace ngraph
class SigmoidMultiply;
class SigmoidMultiplyBackprop;
class Result;
class And;
class Or;
class Xor;
class CompiledKernel;
class Dropout;
class Dequantize;
......
......@@ -168,9 +168,9 @@ namespace ngraph
{
class GCPUBackend;
class GCPUExecutable;
}
}
}
} // namespace gcpu
} // namespace runtime
} // namespace ngraph
class ngraph::runtime::gcpu::GCPUExecutable : public Executable
{
......@@ -1606,6 +1606,7 @@ private:
}
break;
}
case OP_TYPEID::LogicalXor:
case OP_TYPEID::Xor:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -1006,6 +1006,17 @@ private:
less_eq->get_autob());
break;
}
case OP_TYPEID::LessEqual:
{
auto less_eq = static_cast<const op::v1::LessEqual*>(&node);
reference::less_eq<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_input_shape(1),
less_eq->get_autob());
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
......@@ -1013,6 +1024,39 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogicalAnd:
{
auto logical_and = static_cast<const op::v1::LogicalAnd*>(&node);
reference::logical_and(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
logical_and->get_autob());
break;
}
case OP_TYPEID::LogicalOr:
{
auto logical_or = static_cast<const op::v1::LogicalOr*>(&node);
reference::logical_or(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
logical_or->get_autob());
break;
}
case OP_TYPEID::LogicalXor:
{
auto logical_xor = static_cast<const op::v1::LogicalXor*>(&node);
reference::logical_xor(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
logical_xor->get_autob());
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
......@@ -1116,6 +1160,7 @@ private:
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogicalNot:
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
......
......@@ -1556,7 +1556,13 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::LessEq:
{
node = make_shared<op::LessEq>(
node = make_shared<op::v0::LessEq>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LessEqual:
{
node = make_shared<op::v1::LessEqual>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
......@@ -1565,6 +1571,29 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Log>(args[0]);
break;
}
case OP_TYPEID::LogicalAnd:
{
node = make_shared<op::v1::LogicalAnd>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LogicalNot:
{
node = make_shared<op::v1::LogicalNot>(args[0]);
break;
}
case OP_TYPEID::LogicalOr:
{
node = make_shared<op::v1::LogicalOr>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LogicalXor:
{
node = make_shared<op::v1::LogicalXor>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
case OP_TYPEID::LRN:
{
auto alpha = node_js.at("alpha").get<double>();
......@@ -1820,7 +1849,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Or:
{
node = make_shared<op::Or>(
node = make_shared<op::v0::Or>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
......@@ -2355,7 +2384,7 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
}
case OP_TYPEID::Xor:
{
node = make_shared<op::Xor>(
node = make_shared<op::v0::Xor>(
args[0], args[1], read_auto_broadcast(node_js, "auto_broadcast"));
break;
}
......@@ -3081,7 +3110,16 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::LessEq:
{
auto tmp = static_cast<const op::LessEq*>(&n);
auto tmp = static_cast<const op::v0::LessEq*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::LessEqual:
{
auto tmp = static_cast<const op::v1::LessEqual*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
......@@ -3090,6 +3128,35 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Log: { break;
}
case OP_TYPEID::LogicalAnd:
{
auto tmp = static_cast<const op::v1::LogicalAnd*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::LogicalNot: { break;
}
case OP_TYPEID::LogicalOr:
{
auto tmp = static_cast<const op::v1::LogicalOr*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::LogicalXor:
{
auto tmp = static_cast<const op::v1::LogicalXor*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
}
break;
}
case OP_TYPEID::LRN:
{
auto tmp = static_cast<const op::LRN*>(&n);
......@@ -3248,7 +3315,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Or:
{
auto tmp = static_cast<const op::Or*>(&n);
auto tmp = static_cast<const op::v0::Or*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
......@@ -3613,7 +3680,7 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Xor:
{
auto tmp = static_cast<const op::Xor*>(&n);
auto tmp = static_cast<const op::v0::Xor*>(&n);
if (tmp->get_autob().m_type != op::AutoBroadcastType::NONE)
{
node["auto_broadcast"] = write_auto_broadcast(tmp->get_autob());
......
......@@ -72,6 +72,11 @@ set(SRC
opset_pass/broadcast_opset_pass.cpp
opset_pass/convolution_opset_pass.cpp
opset_pass/dyn_reshape_opset_pass.cpp
opset_pass/logical_and_opset_pass.cpp
opset_pass/logical_not_opset_pass.cpp
opset_pass/logical_or_opset_pass.cpp
opset_pass/logical_less_equal_opset_pass.cpp
opset_pass/logical_xor_opset_pass.cpp
opset_pass/gather_opset_pass.cpp
opset_pass/generate_mask_opset_pass.cpp
opset_pass/pad_opset_pass.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_and_upgrade_pass)
{
    // Build a graph around a v0 And and verify that Opset1Upgrade replaces it
    // with a v1 LogicalAnd producing a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto and_v0 = make_shared<op::v0::And>(a, b);
    const auto result = make_shared<op::Result>(and_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto and_v1 = dynamic_pointer_cast<op::v1::LogicalAnd>(pass_replacement_node);
    ASSERT_TRUE(and_v1);
    EXPECT_EQ(and_v1->description(), "LogicalAnd");
    EXPECT_EQ(and_v1->get_version(), 1);
    const auto values_out_element_type = and_v1->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_and_downgrade_pass)
{
    // Build a graph around a v1 LogicalAnd and verify that Opset0Downgrade
    // converts it back to a v0 And with a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto and_v1 = make_shared<op::v1::LogicalAnd>(a, b);
    const auto result = make_shared<op::Result>(and_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto and_v0 = dynamic_pointer_cast<op::v0::And>(pass_replacement_node);
    ASSERT_TRUE(and_v0);
    EXPECT_EQ(and_v0->description(), "And");
    EXPECT_EQ(and_v0->get_version(), 0);
    const auto values_out_element_type = and_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_less_equal_upgrade_pass)
{
    // Build a graph around a v0 LessEq and verify that Opset1Upgrade replaces
    // it with a v1 LessEqual producing a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto less_eq_v0 = make_shared<op::v0::LessEq>(a, b);
    const auto result = make_shared<op::Result>(less_eq_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto less_eq_v1 = dynamic_pointer_cast<op::v1::LessEqual>(pass_replacement_node);
    ASSERT_TRUE(less_eq_v1);
    EXPECT_EQ(less_eq_v1->description(), "LessEqual");
    EXPECT_EQ(less_eq_v1->get_version(), 1);
    const auto values_out_element_type = less_eq_v1->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_less_equal_downgrade_pass)
{
    // Build a graph around a v1 LessEqual and verify that Opset0Downgrade
    // converts it back to a v0 LessEq with a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto less_eq_v1 = make_shared<op::v1::LessEqual>(a, b);
    const auto result = make_shared<op::Result>(less_eq_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto less_eq_v0 = dynamic_pointer_cast<op::v0::LessEq>(pass_replacement_node);
    ASSERT_TRUE(less_eq_v0);
    EXPECT_EQ(less_eq_v0->description(), "LessEq");
    EXPECT_EQ(less_eq_v0->get_version(), 0);
    const auto values_out_element_type = less_eq_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_not_upgrade_pass)
{
    // Build a graph around a v0 Not and verify that Opset1Upgrade replaces it
    // with a v1 LogicalNot producing a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto not_v0 = make_shared<op::v0::Not>(a);
    const auto result = make_shared<op::Result>(not_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto not_v1 = dynamic_pointer_cast<op::v1::LogicalNot>(pass_replacement_node);
    ASSERT_TRUE(not_v1);
    EXPECT_EQ(not_v1->description(), "LogicalNot");
    EXPECT_EQ(not_v1->get_version(), 1);
    const auto values_out_element_type = not_v1->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_not_downgrade_pass)
{
    // Build a graph around a v1 LogicalNot and verify that Opset0Downgrade
    // converts it back to a v0 Not with a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto not_v1 = make_shared<op::v1::LogicalNot>(a);
    const auto result = make_shared<op::Result>(not_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto not_v0 = dynamic_pointer_cast<op::v0::Not>(pass_replacement_node);
    ASSERT_TRUE(not_v0);
    EXPECT_EQ(not_v0->description(), "Not");
    EXPECT_EQ(not_v0->get_version(), 0);
    const auto values_out_element_type = not_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_or_upgrade_pass)
{
    // Build a graph around a v0 Or and verify that Opset1Upgrade replaces it
    // with a v1 LogicalOr producing a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto or_v0 = make_shared<op::v0::Or>(a, b);
    const auto result = make_shared<op::Result>(or_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto or_v1 = dynamic_pointer_cast<op::v1::LogicalOr>(pass_replacement_node);
    ASSERT_TRUE(or_v1);
    EXPECT_EQ(or_v1->description(), "LogicalOr");
    EXPECT_EQ(or_v1->get_version(), 1);
    const auto values_out_element_type = or_v1->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_or_downgrade_pass)
{
    // Build a graph around a v1 LogicalOr and verify that Opset0Downgrade
    // converts it back to a v0 Or with a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto or_v1 = make_shared<op::v1::LogicalOr>(a, b);
    const auto result = make_shared<op::Result>(or_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto or_v0 = dynamic_pointer_cast<op::v0::Or>(pass_replacement_node);
    ASSERT_TRUE(or_v0);
    EXPECT_EQ(or_v0->description(), "Or");
    EXPECT_EQ(or_v0->get_version(), 0);
    const auto values_out_element_type = or_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/opset0_downgrade.hpp"
#include "ngraph/pass/opset1_upgrade.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(opset_transform, opset1_logical_xor_upgrade_pass)
{
    // Build a graph around a v0 Xor and verify that Opset1Upgrade replaces it
    // with a v1 LogicalXor producing a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto xor_v0 = make_shared<op::v0::Xor>(a, b);
    const auto result = make_shared<op::Result>(xor_v0);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset1Upgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto xor_v1 = dynamic_pointer_cast<op::v1::LogicalXor>(pass_replacement_node);
    ASSERT_TRUE(xor_v1);
    EXPECT_EQ(xor_v1->description(), "LogicalXor");
    EXPECT_EQ(xor_v1->get_version(), 1);
    const auto values_out_element_type = xor_v1->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
TEST(opset_transform, opset1_logical_xor_downgrade_pass)
{
    // Build a graph around a v1 LogicalXor and verify that Opset0Downgrade
    // converts it back to a v0 Xor with a boolean output.
    const auto a = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto b = make_shared<op::Parameter>(element::boolean, Shape{5, 10, 15});
    const auto xor_v1 = make_shared<op::v1::LogicalXor>(a, b);
    const auto result = make_shared<op::Result>(xor_v1);
    auto f = make_shared<Function>(ResultVector{result}, ParameterVector{a, b});

    ngraph::pass::Manager pass_manager;
    pass_manager.register_pass<pass::Opset0Downgrade>();
    pass_manager.run_passes(f);

    const auto pass_replacement_node =
        f->get_result()->input(0).get_source_output().get_node_shared_ptr();
    // dynamic_pointer_cast (not static) so a missing/incorrect replacement
    // fails the ASSERT cleanly instead of invoking undefined behavior.
    const auto xor_v0 = dynamic_pointer_cast<op::v0::Xor>(pass_replacement_node);
    ASSERT_TRUE(xor_v0);
    EXPECT_EQ(xor_v0->description(), "Xor");
    EXPECT_EQ(xor_v0->get_version(), 0);
    const auto values_out_element_type = xor_v0->output(0).get_element_type();
    EXPECT_EQ(values_out_element_type, element::boolean);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment