Unverified Commit 1b611294 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops S-X (#4325)

* Add attribute visitor to xor op

* Add attribute visitor to VariadicSplit op

* Add attribute visitor to Unsqueeze op

* Add attribute visitor to Transpose op

* Add attribute visitor to TopK op

* Add attribute visitor to Tile op

* Add attribute visitor to Tanh op

* Add attribute visitor to Tan op

* Add attribute visitor to StridedSlice op

* Add attribute visitor to Squeeze op

* Add attribute visitor to SquaredDifference op

* Add attribute visitor to Split op

* Add attribute visitor to SpaceToDepth op

* Add attribute visitor to Sqrt op

* Add attribute visitor to Softmax op

* Add attribute visitor to Sinh op

* Add attribute visitor to Sin op

* Add attribute visitor to Sign op

* Add attribute visitor to ShuffleChannels op

* Add attribute visitor to ShapeOf op

* Add attribute visitor to Selu op

* style

* Review Fix I

* Add attribute visitor to TensorIterator op

* Revert "Add attribute visitor to TensorIterator op"

This reverts commit 04850068f3a3e3b6ddcea58023327c76df574fa1.

* Add support for AutoBroadcast in attribute_visitor

* Add tests for operators with attributes
Co-authored-by: 's avatarScott Cyphers <diyessi@users.noreply.github.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent faff970a
...@@ -33,6 +33,11 @@ void op::ShapeOf::validate_and_infer_types() ...@@ -33,6 +33,11 @@ void op::ShapeOf::validate_and_infer_types()
set_output_type(0, element::i64, PartialShape{get_input_partial_shape(0).rank()}); set_output_type(0, element::i64, PartialShape{get_input_partial_shape(0).rank()});
} }
bool ngraph::op::v0::ShapeOf::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::ShapeOf::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::ShapeOf::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -34,6 +34,7 @@ namespace ngraph ...@@ -34,6 +34,7 @@ namespace ngraph
/// \brief Constructs a shape-of operation. /// \brief Constructs a shape-of operation.
ShapeOf(const Output<Node>& arg); ShapeOf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -29,6 +29,11 @@ op::Tile::Tile(const Output<Node>& data, const Output<Node>& repeats) ...@@ -29,6 +29,11 @@ op::Tile::Tile(const Output<Node>& data, const Output<Node>& repeats)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Tile::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Tile::validate_and_infer_types() void op::Tile::validate_and_infer_types()
{ {
auto arg_et = get_input_element_type(0); auto arg_et = get_input_element_type(0);
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
/// \param repeats The node producing the per-dimension replication factor /// \param repeats The node producing the per-dimension replication factor
Tile(const Output<Node>& data, const Output<Node>& repeats); Tile(const Output<Node>& data, const Output<Node>& repeats);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -30,6 +30,11 @@ op::v1::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_ ...@@ -30,6 +30,11 @@ op::v1::Transpose::Transpose(const Output<Node>& arg, const Output<Node>& input_
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Transpose::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::Transpose::validate_and_infer_types() void op::v1::Transpose::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
/// value must contain every integer in the range [0,n-1]. /// value must contain every integer in the range [0,n-1].
Transpose(const Output<Node>& arg, const Output<Node>& input_order); Transpose(const Output<Node>& arg, const Output<Node>& input_order);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -34,6 +34,11 @@ op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Ou ...@@ -34,6 +34,11 @@ op::v0::Selu::Selu(const Output<Node>& data, const Output<Node>& alpha, const Ou
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Selu::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
NodeVector op::v0::Selu::decompose_op() const NodeVector op::v0::Selu::decompose_op() const
{ {
const auto data = input_value(0); const auto data = input_value(0);
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
const Output<Node>& alpha, const Output<Node>& alpha,
const Output<Node>& lambda); const Output<Node>& lambda);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/shuffle_channels.hpp" #include "ngraph/op/fused/shuffle_channels.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
using namespace std; using namespace std;
...@@ -30,6 +31,13 @@ op::ShuffleChannels::ShuffleChannels(const Output<Node>& data, const int axis, c ...@@ -30,6 +31,13 @@ op::ShuffleChannels::ShuffleChannels(const Output<Node>& data, const int axis, c
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::ShuffleChannels::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("groups", m_groups);
return true;
}
size_t op::ShuffleChannels::get_zero_based_axis() const size_t op::ShuffleChannels::get_zero_based_axis() const
{ {
if (m_axis >= 0) if (m_axis >= 0)
......
...@@ -48,6 +48,7 @@ namespace ngraph ...@@ -48,6 +48,7 @@ namespace ngraph
const int axis = 1, const int axis = 1,
const size_t groups = 1UL); const size_t groups = 1UL);
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_zero_based_axis() const; size_t get_zero_based_axis() const;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cstddef> #include <cstddef>
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/shape.hpp" #include "ngraph/shape.hpp"
#include "space_to_depth.hpp" #include "space_to_depth.hpp"
...@@ -41,6 +42,13 @@ op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const std::string& mode ...@@ -41,6 +42,13 @@ op::SpaceToDepth::SpaceToDepth(const Output<Node>& data, const std::string& mode
{ {
} }
bool ngraph::op::v0::SpaceToDepth::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("block_size", m_blocksize);
visitor.on_attribute("mode", m_mode);
return true;
}
NodeVector op::SpaceToDepth::decompose_op() const NodeVector op::SpaceToDepth::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
...@@ -153,3 +161,24 @@ op::SpaceToDepth::SpaceToDepthMode op::SpaceToDepth::mode_from_string(const std: ...@@ -153,3 +161,24 @@ op::SpaceToDepth::SpaceToDepthMode op::SpaceToDepth::mode_from_string(const std:
return allowed_values.at(mode); return allowed_values.at(mode);
} }
namespace ngraph
{
template <>
EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>&
EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>::get()
{
static auto enum_names = EnumNames<op::v0::SpaceToDepth::SpaceToDepthMode>(
"op::v0::SpaceToDepth::SpaceToDepthMode",
{{"blocks_first", op::v0::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST},
{"depth_first", op::v0::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v0::SpaceToDepth::SpaceToDepthMode& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -60,6 +60,7 @@ namespace ngraph ...@@ -60,6 +60,7 @@ namespace ngraph
const std::string& mode, const std::string& mode,
std::size_t block_size = 1); std::size_t block_size = 1);
bool visit_attributes(AttributeVisitor& visitor) override;
std::size_t get_block_size() const { return m_blocksize; } std::size_t get_block_size() const { return m_blocksize; }
SpaceToDepthMode get_mode() const { return m_mode; } SpaceToDepthMode get_mode() const { return m_mode; }
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
...@@ -74,5 +75,22 @@ namespace ngraph ...@@ -74,5 +75,22 @@ namespace ngraph
}; };
} }
using v0::SpaceToDepth; using v0::SpaceToDepth;
} } // namespace op
}
std::ostream& operator<<(std::ostream& s, const op::v0::SpaceToDepth::SpaceToDepthMode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>
: public EnumAttributeAdapterBase<op::v0::SpaceToDepth::SpaceToDepthMode>
{
public:
AttributeAdapter(op::v0::SpaceToDepth::SpaceToDepthMode& value)
: EnumAttributeAdapterBase<op::v0::SpaceToDepth::SpaceToDepthMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::SpaceToDepth::SpaceToDepthMode>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include <numeric> #include <numeric>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/split.hpp" #include "ngraph/op/fused/split.hpp"
...@@ -121,6 +122,12 @@ op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const s ...@@ -121,6 +122,12 @@ op::v1::Split::Split(const Output<Node>& data, const Output<Node>& axis, const s
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Split::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("num_splits", m_num_splits);
return true;
}
void op::v1::Split::validate_and_infer_types() void op::v1::Split::validate_and_infer_types()
{ {
const auto data_ps = input_value(0).get_partial_shape(); const auto data_ps = input_value(0).get_partial_shape();
......
...@@ -95,6 +95,7 @@ namespace ngraph ...@@ -95,6 +95,7 @@ namespace ngraph
/// split into. /// split into.
Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits); Split(const Output<Node>& data, const Output<Node>& axis, const size_t num_splits);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/squared_difference.hpp" #include "ngraph/op/fused/squared_difference.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
...@@ -34,6 +35,12 @@ op::SquaredDifference::SquaredDifference(const Output<Node>& x1, ...@@ -34,6 +35,12 @@ op::SquaredDifference::SquaredDifference(const Output<Node>& x1,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::SquaredDifference::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("auto_broadcast", m_autobroadcast);
return true;
}
NodeVector op::SquaredDifference::decompose_op() const NodeVector op::SquaredDifference::decompose_op() const
{ {
const auto x1 = input_value(0); const auto x1 = input_value(0);
......
...@@ -45,6 +45,7 @@ namespace ngraph ...@@ -45,6 +45,7 @@ namespace ngraph
const Output<Node>& x2, const Output<Node>& x2,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY); const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -90,6 +90,11 @@ void op::Squeeze::pre_validate_and_infer_types() ...@@ -90,6 +90,11 @@ void op::Squeeze::pre_validate_and_infer_types()
set_output_type(0, get_input_element_type(0), output_data_shape); set_output_type(0, get_input_element_type(0), output_data_shape);
} }
bool ngraph::op::v0::Squeeze::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
NodeVector op::Squeeze::decompose_op() const NodeVector op::Squeeze::decompose_op() const
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
Squeeze() = default; Squeeze() = default;
Squeeze(const Output<Node>& data, const Output<Node>& axes); Squeeze(const Output<Node>& data, const Output<Node>& axes);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
......
...@@ -89,6 +89,11 @@ NodeVector op::Unsqueeze::decompose_op() const ...@@ -89,6 +89,11 @@ NodeVector op::Unsqueeze::decompose_op() const
return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)}; return {make_shared<ngraph::op::Reshape>(data, input_order, output_shape)};
} }
bool ngraph::op::v0::Unsqueeze::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Unsqueeze::copy_with_new_args(const NodeVector& new_args) const
{ {
if (new_args.size() != 2) if (new_args.size() != 2)
......
...@@ -40,6 +40,8 @@ namespace ngraph ...@@ -40,6 +40,8 @@ namespace ngraph
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -27,6 +27,11 @@ op::Sign::Sign(const Output<Node>& arg) ...@@ -27,6 +27,11 @@ op::Sign::Sign(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Sign::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sign::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Sign(const Output<Node>& arg); Sign(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -29,6 +29,11 @@ op::Sin::Sin(const Output<Node>& arg) ...@@ -29,6 +29,11 @@ op::Sin::Sin(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Sin::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sin::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sin::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -50,6 +50,7 @@ namespace ngraph ...@@ -50,6 +50,7 @@ namespace ngraph
Sin(const Output<Node>& arg); Sin(const Output<Node>& arg);
Sin() = default; Sin() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -29,6 +29,11 @@ op::Sinh::Sinh(const Output<Node>& arg) ...@@ -29,6 +29,11 @@ op::Sinh::Sinh(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Sinh::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sinh::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sinh::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -36,6 +36,7 @@ namespace ngraph ...@@ -36,6 +36,7 @@ namespace ngraph
Sinh(const Output<Node>& arg); Sinh(const Output<Node>& arg);
Sinh() = default; Sinh() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <algorithm> #include <algorithm>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
...@@ -164,6 +165,12 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis) ...@@ -164,6 +165,12 @@ op::v1::Softmax::Softmax(const Output<Node>& arg, const size_t axis)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Softmax::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
return true;
}
void op::v1::Softmax::validate_and_infer_types() void op::v1::Softmax::validate_and_infer_types()
{ {
const PartialShape& input_shape = get_input_partial_shape(0); const PartialShape& input_shape = get_input_partial_shape(0);
......
...@@ -88,6 +88,7 @@ namespace ngraph ...@@ -88,6 +88,7 @@ namespace ngraph
/// ///
Softmax(const Output<Node>& arg, const size_t axis); Softmax(const Output<Node>& arg, const size_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
......
...@@ -29,6 +29,11 @@ op::Sqrt::Sqrt(const Output<Node>& arg) ...@@ -29,6 +29,11 @@ op::Sqrt::Sqrt(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Sqrt::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Sqrt::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Sqrt::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -50,6 +50,7 @@ namespace ngraph ...@@ -50,6 +50,7 @@ namespace ngraph
Sqrt(const Output<Node>& arg); Sqrt(const Output<Node>& arg);
Sqrt() = default; Sqrt() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/strided_slice.hpp" #include "ngraph/op/strided_slice.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -66,6 +67,16 @@ op::v1::StridedSlice::StridedSlice(const Output<Node>& data, ...@@ -66,6 +67,16 @@ op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
{ {
} }
bool ngraph::op::v1::StridedSlice::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("begin_mask", m_begin_mask);
visitor.on_attribute("end_mask", m_end_mask);
visitor.on_attribute("new_axis_mask", m_new_axis_mask);
visitor.on_attribute("shrink_axis_mask", m_shrink_axis_mask);
visitor.on_attribute("ellipsis_mask", m_ellipsis_mask);
return true;
}
void op::v1::StridedSlice::validate_and_infer_types() void op::v1::StridedSlice::validate_and_infer_types()
{ {
const auto& begin_mask_et = get_input_element_type(1); const auto& begin_mask_et = get_input_element_type(1);
......
...@@ -90,6 +90,7 @@ namespace ngraph ...@@ -90,6 +90,7 @@ namespace ngraph
const std::vector<int64_t>& shrink_axis_mask = std::vector<int64_t>{}, const std::vector<int64_t>& shrink_axis_mask = std::vector<int64_t>{},
const std::vector<int64_t>& ellipsis_mask = std::vector<int64_t>{}); const std::vector<int64_t>& ellipsis_mask = std::vector<int64_t>{});
bool visit_attributes(AttributeVisitor& visitor) override;
const std::vector<int64_t>& get_begin_mask() const { return m_begin_mask; } const std::vector<int64_t>& get_begin_mask() const { return m_begin_mask; }
const std::vector<int64_t>& get_end_mask() const { return m_end_mask; } const std::vector<int64_t>& get_end_mask() const { return m_end_mask; }
const std::vector<int64_t>& get_new_axis_mask() const { return m_new_axis_mask; } const std::vector<int64_t>& get_new_axis_mask() const { return m_new_axis_mask; }
......
...@@ -30,6 +30,11 @@ op::Tan::Tan(const Output<Node>& arg) ...@@ -30,6 +30,11 @@ op::Tan::Tan(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Tan::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Tan::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Tan::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -50,6 +50,7 @@ namespace ngraph ...@@ -50,6 +50,7 @@ namespace ngraph
Tan(const Output<Node>& arg); Tan(const Output<Node>& arg);
Tan() = default; Tan() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -29,6 +29,11 @@ op::Tanh::Tanh(const Output<Node>& arg) ...@@ -29,6 +29,11 @@ op::Tanh::Tanh(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Tanh::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Tanh::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Tanh::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -36,6 +36,7 @@ namespace ngraph ...@@ -36,6 +36,7 @@ namespace ngraph
Tanh(const Output<Node>& arg); Tanh(const Output<Node>& arg);
Tanh() = default; Tanh() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_vector.hpp" #include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp" #include "ngraph/op/topk.hpp"
...@@ -262,6 +263,14 @@ op::v1::TopK::TopK(const Output<Node>& data, ...@@ -262,6 +263,14 @@ op::v1::TopK::TopK(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::TopK::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("sort", m_sort);
return true;
}
void op::v1::TopK::validate_and_infer_types() void op::v1::TopK::validate_and_infer_types()
{ {
const auto& input_partial_shape = get_input_partial_shape(0); const auto& input_partial_shape = get_input_partial_shape(0);
...@@ -434,3 +443,22 @@ void op::v1::TopK::set_k(size_t k) ...@@ -434,3 +443,22 @@ void op::v1::TopK::set_k(size_t k)
this->input(1).replace_source_output( this->input(1).replace_source_output(
op::Constant::create(element::i64, Shape{}, {k})->output(0)); op::Constant::create(element::i64, Shape{}, {k})->output(0));
} }
namespace ngraph
{
template <>
EnumNames<op::v1::TopK::Mode>& EnumNames<op::v1::TopK::Mode>::get()
{
static auto enum_names = EnumNames<op::v1::TopK::Mode>(
"op::v1::TopK::Mode",
{{"max", op::v1::TopK::Mode::MAX}, {"min", op::v1::TopK::Mode::MIN}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v1::TopK::Mode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v1::TopK::Mode& type)
{
return s << as_string(type);
}
}
...@@ -155,6 +155,7 @@ namespace ngraph ...@@ -155,6 +155,7 @@ namespace ngraph
const SortType sort, const SortType sort,
const element::Type& index_element_type = element::i32); const element::Type& index_element_type = element::i32);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -208,4 +209,20 @@ namespace ngraph ...@@ -208,4 +209,20 @@ namespace ngraph
using v0::TopK; using v0::TopK;
} // op } // op
std::ostream& operator<<(std::ostream& s, const op::v1::TopK::Mode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::TopK::Mode>
: public EnumAttributeAdapterBase<op::v1::TopK::Mode>
{
public:
AttributeAdapter(op::v1::TopK::Mode& value)
: EnumAttributeAdapterBase<op::v1::TopK::Mode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::TopK::Mode>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // ngraph } // ngraph
...@@ -59,6 +59,6 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types() ...@@ -59,6 +59,6 @@ void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor) bool op::util::BinaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{ {
visitor.on_attribute("autob", m_autob); visitor.on_attribute("auto_broadcast", m_autob);
return true; return true;
} }
...@@ -56,6 +56,6 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types() ...@@ -56,6 +56,6 @@ void op::util::BinaryElementwiseLogical::validate_and_infer_types()
bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor) bool op::util::BinaryElementwiseLogical::visit_attributes(AttributeVisitor& visitor)
{ {
visitor.on_attribute("autob", m_autob); visitor.on_attribute("auto_broadcast", m_autob);
return true; return true;
} }
...@@ -33,6 +33,11 @@ op::v1::VariadicSplit::VariadicSplit(const Output<Node>& data, ...@@ -33,6 +33,11 @@ op::v1::VariadicSplit::VariadicSplit(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::VariadicSplit::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void ngraph::op::v1::VariadicSplit::validate_and_infer_types() void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
{ {
set_input_is_relevant_to_value(0); set_input_is_relevant_to_value(0);
......
...@@ -48,6 +48,8 @@ namespace ngraph ...@@ -48,6 +48,8 @@ namespace ngraph
const Output<Node>& axis, const Output<Node>& axis,
const Output<Node>& split_lengths); const Output<Node>& split_lengths);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -35,6 +35,12 @@ shared_ptr<Node> op::v1::LogicalXor::copy_with_new_args(const NodeVector& new_ar ...@@ -35,6 +35,12 @@ shared_ptr<Node> op::v1::LogicalXor::copy_with_new_args(const NodeVector& new_ar
return make_shared<v1::LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob()); return make_shared<v1::LogicalXor>(new_args.at(0), new_args.at(1), this->get_autob());
} }
bool ngraph::op::v1::LogicalXor::visit_attributes(AttributeVisitor& visitor)
{
BinaryElementwiseLogical::visit_attributes(visitor);
return true;
}
constexpr NodeTypeInfo op::v0::Xor::type_info; constexpr NodeTypeInfo op::v0::Xor::type_info;
op::v0::Xor::Xor(const Output<Node>& arg0, op::v0::Xor::Xor(const Output<Node>& arg0,
......
...@@ -53,6 +53,7 @@ namespace ngraph ...@@ -53,6 +53,7 @@ namespace ngraph
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
virtual bool is_commutative() const override { return true; } virtual bool is_commutative() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
}; };
} // namespace v1 } // namespace v1
namespace v0 namespace v0
......
...@@ -498,3 +498,136 @@ TEST(attributes, lstm_sequence_op) ...@@ -498,3 +498,136 @@ TEST(attributes, lstm_sequence_op)
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget()); EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format()); EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
} }
TEST(attributes, shuffle_channels_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::ShuffleChannels>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = 0;
auto groups = 2;
auto shuffle_channels = make_shared<opset1::ShuffleChannels>(data, axis, groups);
NodeBuilder builder(shuffle_channels);
auto g_shuffle_channels = as_type_ptr<opset1::ShuffleChannels>(builder.create());
EXPECT_EQ(g_shuffle_channels->get_axis(), shuffle_channels->get_axis());
EXPECT_EQ(g_shuffle_channels->get_groups(), shuffle_channels->get_groups());
}
TEST(attributes, softmax_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Softmax>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = 0;
auto softmax = make_shared<opset1::Softmax>(data, axis);
NodeBuilder builder(softmax);
auto g_softmax = as_type_ptr<opset1::Softmax>(builder.create());
EXPECT_EQ(g_softmax->get_axis(), softmax->get_axis());
}
TEST(attributes, space_to_depth_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::SpaceToDepth>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 50, 50});
auto block_size = 2;
auto mode = opset1::SpaceToDepth::SpaceToDepthMode::BLOCKS_FIRST;
auto space_to_depth = make_shared<opset1::SpaceToDepth>(data, mode, block_size);
NodeBuilder builder(space_to_depth);
auto g_space_to_depth = as_type_ptr<opset1::SpaceToDepth>(builder.create());
EXPECT_EQ(g_space_to_depth->get_block_size(), space_to_depth->get_block_size());
EXPECT_EQ(g_space_to_depth->get_mode(), space_to_depth->get_mode());
}
TEST(attributes, split_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Split>();
auto data = make_shared<op::Parameter>(element::i32, Shape{200});
auto axis = make_shared<op::Parameter>(element::i32, Shape{});
auto num_splits = 2;
auto split = make_shared<opset1::Split>(data, axis, num_splits);
NodeBuilder builder(split);
auto g_split = as_type_ptr<opset1::Split>(builder.create());
EXPECT_EQ(g_split->get_num_splits(), split->get_num_splits());
}
TEST(attributes, squared_difference_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::SquaredDifference>();
auto x1 = make_shared<op::Parameter>(element::i32, Shape{200});
auto x2 = make_shared<op::Parameter>(element::i32, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto squared_difference = make_shared<opset1::SquaredDifference>(x1, x2, auto_broadcast);
NodeBuilder builder(squared_difference);
auto g_squared_difference = as_type_ptr<opset1::SquaredDifference>(builder.create());
EXPECT_EQ(g_squared_difference->get_autob(), squared_difference->get_autob());
}
TEST(attributes, strided_slice_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::StridedSlice>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4, 5});
auto begin = make_shared<op::Parameter>(element::i32, Shape{2});
auto end = make_shared<op::Parameter>(element::i32, Shape{2});
auto stride = make_shared<op::Parameter>(element::i32, Shape{2});
auto begin_mask = std::vector<int64_t>{0, 0};
auto end_mask = std::vector<int64_t>{0, 0};
auto new_axis_mask = std::vector<int64_t>{0, 0};
auto shrink_axis_mask = std::vector<int64_t>{0, 0};
auto ellipsis_mask = std::vector<int64_t>{0, 0};
auto strided_slice = make_shared<opset1::StridedSlice>(data,
begin,
end,
stride,
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
NodeBuilder builder(strided_slice);
auto g_strided_slice = as_type_ptr<opset1::StridedSlice>(builder.create());
EXPECT_EQ(g_strided_slice->get_begin_mask(), strided_slice->get_begin_mask());
EXPECT_EQ(g_strided_slice->get_end_mask(), strided_slice->get_end_mask());
EXPECT_EQ(g_strided_slice->get_new_axis_mask(), strided_slice->get_new_axis_mask());
EXPECT_EQ(g_strided_slice->get_shrink_axis_mask(), strided_slice->get_shrink_axis_mask());
EXPECT_EQ(g_strided_slice->get_ellipsis_mask(), strided_slice->get_ellipsis_mask());
}
TEST(attributes, topk_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::TopK>();
auto data = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4, 5});
auto k = make_shared<op::Parameter>(element::i32, Shape{});
auto axis = 0;
auto mode = opset1::TopK::Mode::MAX;
auto sort_type = opset1::TopK::SortType::SORT_VALUES;
auto topk = make_shared<opset1::TopK>(data, k, axis, mode, sort_type);
NodeBuilder builder(topk);
auto g_topk = as_type_ptr<opset1::TopK>(builder.create());
EXPECT_EQ(g_topk->get_axis(), topk->get_axis());
EXPECT_EQ(g_topk->get_mode(), topk->get_mode());
EXPECT_EQ(g_topk->get_sort_type(), topk->get_sort_type());
}
TEST(attributes, logical_xor_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LogicalXor>();
auto x1 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto x2 = make_shared<op::Parameter>(element::boolean, Shape{200});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto logical_xor = make_shared<opset1::LogicalXor>(x1, x2, auto_broadcast);
NodeBuilder builder(logical_xor);
auto g_logical_xor = as_type_ptr<opset1::LogicalXor>(builder.create());
EXPECT_EQ(g_logical_xor->get_autob(), logical_xor->get_autob());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment