Unverified Commit 1b611294 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops S-X (#4325)

* Add attribute visitor to xor op

* Add attribute visitor to VariadicSplit op

* Add attribute visitor to Unsqueeze op

* Add attribute visitor to Transpose op

* Add attribute visitor to TopK op

* Add attribute visitor to Tile op

* Add attribute visitor to Tanh op

* Add attribute visitor to Tan op

* Add attribute visitor to StridedSlice op

* Add attribute visitor to Squeeze op

* Add attribute visitor to SquaredDifference op

* Add attribute visitor to Split op

* Add attribute visitor to SpaceToDepth op

* Add attribute visitor to Sqrt op

* Add attribute visitor to Softmax op

* Add attribute visitor to Sinh op

* Add attribute visitor to Sin op

* Add attribute visitor to Sign op

* Add attribute visitor to ShuffleChannels op

* Add attribute visitor to ShapeOf op

* Add attribute visitor to Selu op

* style

* Review Fix I

* Add attribute visitor to TensorIterator op

* Revert "Add attribute visitor to TensorIterator op"

This reverts commit 04850068f3a3e3b6ddcea58023327c76df574fa1.

* Add support for AutoBroadcast in attribute_visitor

* Add tests for operators with attributes
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent faff970a
......@@ -33,6 +33,11 @@ void op::ShapeOf::validate_and_infer_types()
set_output_type(0, element::i64, PartialShape{get_input_partial_shape(0).rank()});
}
bool ngraph::op::v0::ShapeOf::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::ShapeOf::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -34,6 +34,7 @@ namespace ngraph
/// \brief Constructs a shape-of operation.
ShapeOf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -29,6 +29,11 @@ op::Tile::Tile(const Output<Node>& data, const Output<Node>& repeats)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Tile::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Tile::validate_and_infer_types()
{
auto arg_et = get_input_element_type(0);
......
......@@ -38,6 +38,7 @@ namespace ngraph
/// \param repeats The node producing the per-dimension replication factor
Tile(const Output<Node>& data, const Output<Node>& repeats);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -30,6 +30,11 @@ op::v1::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Transpose::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::Transpose::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(this,
......
......@@ -42,6 +42,7 @@ namespace ngraph
/// value must contain every integer in the range [0,n-1].
Transpose(const Output<Node>& arg, const Output<Node>& input_order);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -34,6 +34,11 @@ op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Ou
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Selu::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
NodeVector op::v0::Selu::decompose_op() const
{
const auto data = input_value(0);
......
......@@ -42,6 +42,7 @@ namespace ngraph
const Output<Node>& alpha,
const Output<Node>& lambda);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
using namespace std;
......@@ -30,6 +31,13 @@ op::ShuffleChannels::ShuffleChannels(const Output<Node>& data, const int axis, c
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::ShuffleChannels::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("groups", m_groups);
return true;
}
size_t op::ShuffleChannels::get_zero_based_axis() const
{
if (m_axis >= 0)
......
......@@ -48,6 +48,7 @@ namespace ngraph
const int axis = 1,
const size_t groups = 1UL);
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_zero_based_axis() const;
virtual void pre_validate_and_infer_types() override;
......
......@@ -17,6 +17,7 @@
#include <cstddef>
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/shape.hpp"
#include "space_to_depth.hpp"
......@@ -41,6 +42,13 @@ op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const std::string& mode
{
}
bool ngraph::op::v0::SpaceToDepth::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("block_size", m_blocksize);
visitor.on_attribute("mode", m_mode);
return true;
}
NodeVector op::SpaceToDepth::decompose_op() const
{
auto data = input_value(0);
......@@ -153,3 +161,24 @@ op::SpaceToDepth::SpaceToDepthMode op::SpaceToDepth::mode_from_string(const std:
return allowed_values.at(mode);
}
namespace ngraph
{
template <>
EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>&
EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>::get()
{
static auto enum_names = EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>(
"op::v0::SpaceToDepth::SpaceToDepthMode",
{{"blocks_first", op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST},
{"depth_first", op::v0::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v0::SpaceToDepth::SpaceToDepthMode& type)
{
return s << as_string(type);
}
} // namespace ngraph
......@@ -60,6 +60,7 @@ namespace ngraph
const std::string& mode,
std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const { return m_blocksize; }
SpaceToDepthMode get_mode() const { return m_mode; }
virtual NodeVector decompose_op() const override;
......@@ -74,5 +75,22 @@ namespace ngraph
};
}
using v0::SpaceToDepth;
}
}
} // namespace op
std::ostream& operator<<(std::ostream& s, const op::v0::SpaceToDepth::SpaceToDepthMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>
: public EnumAttributeAdapterBase<op::v0::SpaceToDepth::SpaceToDepthMode>
{
public:
AttributeAdapter(op::v0::SpaceToDepth::SpaceToDepthMode& value)
: EnumAttributeAdapterBase<op::v0::SpaceToDepth::SpaceToDepthMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include <numeric>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp"
......@@ -121,6 +122,12 @@ op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const s
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Split::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("num_splits", m_num_splits);
return true;
}
void op::v1::Split::validate_and_infer_types()
{
const auto data_ps = input_value(0).get_partial_shape();
......
......@@ -95,6 +95,7 @@ namespace ngraph
/// split into.
Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
......@@ -34,6 +35,12 @@ op::SquaredDifference::SquaredDifference(const Output<Node>& x1,
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::SquaredDifference::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("auto_broadcast", m_autobroadcast);
return true;
}
NodeVector op::SquaredDifference::decompose_op() const
{
const auto x1 = input_value(0);
......
......@@ -45,6 +45,7 @@ namespace ngraph
const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
......
......@@ -90,6 +90,11 @@ void op::Squeeze::pre_validate_and_infer_types()
set_output_type(0, get_input_element_type(0), output_data_shape);
}
bool ngraph::op::v0::Squeeze::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
NodeVector op::Squeeze::decompose_op() const
{
NODE_VALIDATION_CHECK(
......
......@@ -37,6 +37,7 @@ namespace ngraph
Squeeze() = default;
Squeeze(const Output<Node>& data, const Output<Node>& axes);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override;
......
......@@ -89,6 +89,11 @@ NodeVector op::Unsqueeze::decompose_op() const
return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)};
}
bool ngraph::op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
{
if (new_args.size() != 2)
......
......@@ -40,6 +40,8 @@ namespace ngraph
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
......
......@@ -27,6 +27,11 @@ op::Sign::Sign(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Sign::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Sign(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};
......
......@@ -29,6 +29,11 @@ op::Sin::Sin(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Sin::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sin::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -50,6 +50,7 @@ namespace ngraph
Sin(const Output<Node>& arg);
Sin() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -29,6 +29,11 @@ op::Sinh::Sinh(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Sinh::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sinh::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -36,6 +36,7 @@ namespace ngraph
Sinh(const Output<Node>& arg);
Sinh() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -18,6 +18,7 @@
#include <algorithm>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
......@@ -164,6 +165,12 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Softmax::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
return true;
}
void op::v1::Softmax::validate_and_infer_types()
{
const PartialShape& input_shape = get_input_partial_shape(0);
......
......@@ -88,6 +88,7 @@ namespace ngraph
///
Softmax(const Output<Node>& arg, const size_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
size_t get_version() const override { return 1; }
......
......@@ -29,6 +29,11 @@ op::Sqrt::Sqrt(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Sqrt::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sqrt::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -50,6 +50,7 @@ namespace ngraph
Sqrt(const Output<Node>& arg);
Sqrt() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/strided_slice.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
......@@ -66,6 +67,16 @@ op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
{
}
bool ngraph::op::v1::StridedSlice::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("begin_mask", m_begin_mask);
visitor.on_attribute("end_mask", m_end_mask);
visitor.on_attribute("new_axis_mask", m_new_axis_mask);
visitor.on_attribute("shrink_axis_mask", m_shrink_axis_mask);
visitor.on_attribute("ellipsis_mask", m_ellipsis_mask);
return true;
}
void op::v1::StridedSlice::validate_and_infer_types()
{
const auto& begin_mask_et = get_input_element_type(1);
......
......@@ -90,6 +90,7 @@ namespace ngraph
const std::vector<int64_t>& shrink_axis_mask = std::vector<int64_t>{},
const std::vector<int64_t>& ellipsis_mask = std::vector<int64_t>{});
bool visit_attributes(AttributeVisitor& visitor) override;
const std::vector<int64_t>& get_begin_mask() const { return m_begin_mask; }
const std::vector<int64_t>& get_end_mask() const { return m_end_mask; }
const std::vector<int64_t>& get_new_axis_mask() const { return m_new_axis_mask; }
......
......@@ -30,6 +30,11 @@ op::Tan::Tan(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Tan::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Tan::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -50,6 +50,7 @@ namespace ngraph
Tan(const Output<Node>& arg);
Tan() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -29,6 +29,11 @@ op::Tanh::Tanh(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Tanh::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Tanh::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -36,6 +36,7 @@ namespace ngraph
Tanh(const Output<Node>& arg);
Tanh() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -16,6 +16,7 @@
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp"
......@@ -262,6 +263,14 @@ op::v1::TopK::TopK(const Output<Node>& data,
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("sort", m_sort);
return true;
}
void op::v1::TopK::validate_and_infer_types()
{
const auto& input_partial_shape = get_input_partial_shape(0);
......@@ -434,3 +443,22 @@ void op::v1::TopK::set_k(size_t k)
this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{}, {k})->output(0));
}
namespace ngraph
{
template <>
EnumNames<op::v1::TopK::Mode>& EnumNames<op::v1::TopK::Mode>::get()
{
static auto enum_names = EnumNames<op::v1::TopK::Mode>(
"op::v1::TopK::Mode",
{{"max", op::v1::TopK::Mode::MAX}, {"min", op::v1::TopK::Mode::MIN}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v1::TopK::Mode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v1::TopK::Mode& type)
{
return s << as_string(type);
}
}
......@@ -155,6 +155,7 @@ namespace ngraph
const SortType sort,
const element::Type& index_element_type = element::i32);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......@@ -208,4 +209,20 @@ namespace ngraph
using v0::TopK;
} // op
std::ostream& operator<<(std::ostream& s, const op::v1::TopK::Mode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::TopK::Mode>
: public EnumAttributeAdapterBase<op::v1::TopK::Mode>
{
public:
AttributeAdapter(op::v1::TopK::Mode& value)
: EnumAttributeAdapterBase<op::v1::TopK::Mode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::TopK::Mode>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // ngraph
......@@ -59,6 +59,6 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("autob", m_autob);
visitor.on_attribute("auto_broadcast", m_autob);
return true;
}
......@@ -56,6 +56,6 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types()
bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("autob", m_autob);
visitor.on_attribute("auto_broadcast", m_autob);
return true;
}
......@@ -33,6 +33,11 @@ op::v1::VariadicSplit::VariadicSplit(const Output<Node>& data,
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::VariadicSplit::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
{
set_input_is_relevant_to_value(0);
......
......@@ -48,6 +48,8 @@ namespace ngraph
const Output<Node>& axis,
const Output<Node>& split_lengths);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......
......@@ -35,6 +35,12 @@ shared_ptr<Node> op::v1::LogicalXor::copy_with_new_args(const NodeVector& new_ar
return make_shared<v1::LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob());
}
bool ngraph::op::v1::LogicalXor::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
constexpr NodeTypeInfo op::v0::Xor::type_info;
op::v0::Xor::Xor(const Output<Node>& arg0,
......
......@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
};
} // namespace v1
namespace v0
......
......@@ -498,3 +498,136 @@ TEST(attributes, lstm_sequence_op)
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}
TEST(attributes, shuffle_channels_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::ShuffleChannels>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = 0;
auto groups = 2;
auto shuffle_channels = make_shared<opset1::ShuffleChannels>(data, axis, groups);
NodeBuilder builder(shuffle_channels);
auto g_shuffle_channels = as_type_ptr<opset1::ShuffleChannels>(builder.create());
EXPECT_EQ(g_shuffle_channels->get_axis(), shuffle_channels->get_axis());
EXPECT_EQ(g_shuffle_channels->get_groups(), shuffle_channels->get_groups());
}
TEST(attributes, softmax_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Softmax>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = 0;
auto softmax = make_shared<opset1::Softmax>(data, axis);
NodeBuilder builder(softmax);
auto g_softmax = as_type_ptr<opset1::Softmax>(builder.create());
EXPECT_EQ(g_softmax->get_axis(), softmax->get_axis());
}
TEST(attributes, space_to_depth_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::SpaceToDepth>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 50, 50});
auto block_size = 2;
auto mode = opset1::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST;
auto space_to_depth = make_shared<opset1::SpaceToDepth>(data, mode, block_size);
NodeBuilder builder(space_to_depth);
auto g_space_to_depth = as_type_ptr<opset1::SpaceToDepth>(builder.create());
EXPECT_EQ(g_space_to_depth->get_block_size(), space_to_depth->get_block_size());
EXPECT_EQ(g_space_to_depth->get_mode(), space_to_depth->get_mode());
}
TEST(attributes, split_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Split>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = make_shared<op::Parameter>(element::i32, Shape{});
auto num_splits = 2;
auto split = make_shared<opset1::Split>(data, axis, num_splits);
NodeBuilder builder(split);
auto g_split = as_type_ptr<opset1::Split>(builder.create());
EXPECT_EQ(g_split->get_num_splits(), split->get_num_splits());
}
TEST(attributes, squared_difference_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::SquaredDifference>();
auto x1 = make_shared<op::Parameter>(element::i32, Shape{200});
auto x2 = make_shared<op::Parameter>(element::i32, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto squared_difference = make_shared<opset1::SquaredDifference>(x1, x2, auto_broadcast);
NodeBuilder builder(squared_difference);
auto g_squared_difference = as_type_ptr<opset1::SquaredDifference>(builder.create());
EXPECT_EQ(g_squared_difference->get_autob(), squared_difference->get_autob());
}
TEST(attributes, strided_slice_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::StridedSlice>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4, 5});
auto begin = make_shared<op::Parameter>(element::i32, Shape{2});
auto end = make_shared<op::Parameter>(element::i32, Shape{2});
auto stride = make_shared<op::Parameter>(element::i32, Shape{2});
auto begin_mask = std::vector<int64_t>{0, 0};
auto end_mask = std::vector<int64_t>{0, 0};
auto new_axis_mask = std::vector<int64_t>{0, 0};
auto shrink_axis_mask = std::vector<int64_t>{0, 0};
auto ellipsis_mask = std::vector<int64_t>{0, 0};
auto strided_slice = make_shared<opset1::StridedSlice>(data,
begin,
end,
stride,
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
NodeBuilder builder(strided_slice);
auto g_strided_slice = as_type_ptr<opset1::StridedSlice>(builder.create());
EXPECT_EQ(g_strided_slice->get_begin_mask(), strided_slice->get_begin_mask());
EXPECT_EQ(g_strided_slice->get_end_mask(), strided_slice->get_end_mask());
EXPECT_EQ(g_strided_slice->get_new_axis_mask(), strided_slice->get_new_axis_mask());
EXPECT_EQ(g_strided_slice->get_shrink_axis_mask(), strided_slice->get_shrink_axis_mask());
EXPECT_EQ(g_strided_slice->get_ellipsis_mask(), strided_slice->get_ellipsis_mask());
}
TEST(attributes, topk_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::TopK>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4, 5});
auto k = make_shared<op::Parameter>(element::i32, Shape{});
auto axis = 0;
auto mode = opset1::TopK::Mode::MAX;
auto sort_type = opset1::TopK::SortType::SORT_VALUES;
auto topk = make_shared<opset1::TopK>(data, k, axis, mode, sort_type);
NodeBuilder builder(topk);
auto g_topk = as_type_ptr<opset1::TopK>(builder.create());
EXPECT_EQ(g_topk->get_axis(), topk->get_axis());
EXPECT_EQ(g_topk->get_mode(), topk->get_mode());
EXPECT_EQ(g_topk->get_sort_type(), topk->get_sort_type());
}
TEST(attributes, logical_xor_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LogicalXor>();
auto x1 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto x2 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto logical_xor = make_shared<opset1::LogicalXor>(x1, x2, auto_broadcast);
NodeBuilder builder(logical_xor);
auto g_logical_xor = as_type_ptr<opset1::LogicalXor>(builder.create());
EXPECT_EQ(g_logical_xor->get_autob(), logical_xor->get_autob());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment