Unverified Commit 5d8c39e9 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops R (#4340)

parent 1b611294
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "region_yolo.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types();
}
// Exposes every serializable RegionYolo attribute to the visitor so the op
// can be generically saved/restored. The visitor sees the serialized names
// ("coords", "classes", "num", ...) which differ from the member names.
// Always reports success.
bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("anchors", m_anchors);
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("coords", m_num_coords);
visitor.on_attribute("classes", m_num_classes);
visitor.on_attribute("end_axis", m_end_axis);
visitor.on_attribute("num", m_num_regions);
visitor.on_attribute("do_softmax", m_do_softmax);
visitor.on_attribute("mask", m_mask);
return true;
}
void op::RegionYolo::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
......
......@@ -55,6 +55,7 @@ namespace ngraph
const int end_axis,
const std::vector<float>& anchors = std::vector<float>{});
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ
return result;
}
// Range carries no attributes: start/stop/step are node inputs, not
// attributes, so there is nothing to visit. Always reports success.
bool ngraph::op::v0::Range::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Range::validate_and_infer_types()
{
set_input_is_relevant_to_shape(0);
......
......@@ -46,6 +46,7 @@ namespace ngraph
const Output<Node>& stop,
const Output<Node>& step);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
constructor_validate_and_infer_types();
}
// RNNCell adds no attributes of its own; all serializable state
// (hidden_size, activations, clip, ...) lives in RNNCellBase, so the
// visit is delegated entirely to the base class.
bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
{
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::RNNCell::pre_validate_and_infer_types()
{
if (is_dynamic())
......
......@@ -132,6 +132,7 @@ namespace ngraph
const std::vector<float>& activations_beta = {},
float clip = 0.f);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
......
......@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b
constructor_validate_and_infer_types();
}
// Visits Reshape's single attribute; the output pattern itself is an input,
// not an attribute. Always reports success.
bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("special_zero", m_special_zero);
return true;
}
void op::v1::Reshape::validate_and_infer_types()
{
auto pattern_et = get_input_element_type(1);
......
......@@ -137,6 +137,7 @@ namespace ngraph
/// from input shape at the same index.
Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
size_t get_version() const override { return 1; }
......
......@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
set_placement_index(input_value(0).get_node()->get_placement_index());
}
// Result visits no attributes.
// NOTE(review): the constructor's needs_default_layout flag is not exposed
// here and therefore will not round-trip through serialization — confirm
// this is intentional.
bool ngraph::op::v0::Result::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Result::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(
......
......@@ -38,6 +38,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
......@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data,
constructor_validate_and_infer_types();
}
// Visits the mode attribute (INDEX or MASK); serialized via the
// AttributeAdapter<op::v1::Reverse::Mode> enum adapter declared in the
// header. The reversed axes are an input, not an attribute.
bool ngraph::op::v1::Reverse::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("mode", m_mode);
return true;
}
void op::v1::Reverse::validate_and_infer_types()
{
if (m_mode == Mode::MASK)
......@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode)
return allowed_values.at(mode);
}
namespace ngraph
{
// Maps op::v1::Reverse::Mode enum values to/from their serialized string
// names ("index"/"mask"); used by the enum attribute adapter.
template <>
EnumNames<op::v1::Reverse::Mode>& EnumNames<op::v1::Reverse::Mode>::get()
{
static auto enum_names = EnumNames<op::v1::Reverse::Mode>(
"op::v1::Reverse::Mode",
{{"index", op::v1::Reverse::Mode::INDEX}, {"mask", op::v1::Reverse::Mode::MASK}});
return enum_names;
}
// Out-of-class definition for the static constexpr member (pre-C++17 ODR).
constexpr DiscreteTypeInfo AttributeAdapter<op::v1::Reverse::Mode>::type_info;
// Streams the mode's string name (via as_string) for diagnostics/logging.
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type)
{
return s << as_string(type);
}
} // namespace ngraph
......@@ -108,6 +108,7 @@ namespace ngraph
const Output<Node>& reversed_axes,
const Mode mode);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......@@ -135,4 +136,20 @@ namespace ngraph
// default opset version
using v0::Reverse;
}
// Stream insertion for Reverse::Mode; definition lives in reverse.cpp.
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type);
// Adapter that lets AttributeVisitor treat Reverse::Mode as a named enum
// attribute, using the EnumNames<op::v1::Reverse::Mode> string mapping.
template <>
class NGRAPH_API AttributeAdapter<op::v1::Reverse::Mode>
: public EnumAttributeAdapterBase<op::v1::Reverse::Mode>
{
public:
AttributeAdapter(op::v1::Reverse::Mode& value)
: EnumAttributeAdapterBase<op::v1::Reverse::Mode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::Reverse::Mode>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
......@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
// Visits the two axis attributes used to locate the batch and sequence
// dimensions; the sequence-length tensor is an input, not an attribute.
bool ngraph::op::v0::ReverseSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("batch_axis", m_batch_axis);
visitor.on_attribute("seq_axis", m_seq_axis);
return true;
}
void op::ReverseSequence::validate_and_infer_types()
{
auto input_shape = get_input_partial_shape(0);
......
......@@ -38,6 +38,7 @@ namespace ngraph
int64_t batch_axis,
int64_t seq_axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
......@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
{
}
// Base-class visitor shared by all arithmetic reductions (ReduceSum,
// ReduceMax, ...): the only attribute is keep_dims; the reduction axes
// are an input, not an attribute.
bool ngraph::op::util::ArithmeticReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
......
......@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes,
bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public:
void validate_and_infer_types() override;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
......@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
{
}
// Base-class visitor shared by the logical reductions (ReduceLogicalAnd,
// ReduceLogicalOr): the only attribute is keep_dims; the reduction axes
// are an input, not an attribute.
bool ngraph::op::util::LogicalReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
......
......@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes,
const bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public:
void validate_and_infer_types() override;
......
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp"
......@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
{
}
// Visits the recurrent-cell attributes shared by all RNN-style cells
// (RNNCell, and other RNNCellBase subclasses that delegate here).
// Always reports success.
bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
return true;
}
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
......
......@@ -58,6 +58,7 @@ namespace ngraph
RNNCellBase() = default;
virtual bool visit_attributes(AttributeVisitor& visitor);
std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; }
......
......@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic();
}
// Unary elementwise arithmetic ops carry no attributes; provided so
// subclasses (Abs, Exp, ...) get a working visitor without overriding.
bool op::util::UnaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
......@@ -68,6 +68,7 @@ namespace ngraph
public:
void validate_and_infer_types() override;
bool is_unary_elementwise_arithmetic() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
};
}
}
......
......@@ -285,6 +285,259 @@ TEST(attributes, user_op)
EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
}
TEST(attributes, reduce_logical_and_op)
{
    // ReduceLogicalAnd derives visit_attributes from op::util::LogicalReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceLogicalAnd>();
    auto data = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    auto reduction_axes = make_shared<op::Parameter>(element::i64, Shape{2});
    bool keep_dims = true;

    // BUG FIX: the original constructed and casted opset1::ReduceSum, so the
    // registered ReduceLogicalAnd visitor was never exercised by this test.
    auto reduce_logical_and =
        make_shared<opset1::ReduceLogicalAnd>(data, reduction_axes, keep_dims);
    NodeBuilder builder(reduce_logical_and);
    auto g_reduce_logical_and = as_type_ptr<opset1::ReduceLogicalAnd>(builder.create());

    // keep_dims is the only attribute and must round-trip through the builder.
    EXPECT_EQ(g_reduce_logical_and->get_keep_dims(), reduce_logical_and->get_keep_dims());
}
TEST(attributes, reduce_logical_or_op)
{
    // ReduceLogicalOr derives visit_attributes from op::util::LogicalReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceLogicalOr>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceLogicalOr>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceLogicalOr>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, reduce_max_op)
{
    // ReduceMax derives visit_attributes from op::util::ArithmeticReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceMax>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceMax>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceMax>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, reduce_mean_op)
{
    // ReduceMean derives visit_attributes from op::util::ArithmeticReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceMean>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceMean>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceMean>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, reduce_min_op)
{
    // ReduceMin derives visit_attributes from op::util::ArithmeticReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceMin>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceMin>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceMin>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, reduce_prod_op)
{
    // ReduceProd derives visit_attributes from op::util::ArithmeticReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceProd>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceProd>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceProd>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, reduce_sum_op)
{
    // ReduceSum derives visit_attributes from op::util::ArithmeticReductionKeepDims
    FactoryRegistry<Node>::get().register_factory<opset1::ReduceSum>();
    const auto input = make_shared<op::Parameter>(element::f32, Shape{3, 4, 5});
    const auto axes = make_shared<op::Parameter>(element::i64, Shape{2});
    const bool keep_dims = true;

    const auto original = make_shared<opset1::ReduceSum>(input, axes, keep_dims);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReduceSum>(builder.create());

    // The single keep_dims attribute must survive the round trip.
    EXPECT_EQ(restored->get_keep_dims(), original->get_keep_dims());
}
TEST(attributes, region_yolo_op)
{
    // Every RegionYolo attribute must round-trip through the NodeBuilder.
    FactoryRegistry<Node>::get().register_factory<opset1::RegionYolo>();
    const auto input = make_shared<op::Parameter>(element::i64, Shape{1, 255, 26, 26});

    const size_t coords = 4;
    const size_t classes = 1;
    const size_t regions = 6;
    const bool softmax = false;
    const std::vector<int64_t> mask{0, 1};
    const int axis = 1;
    const int end_axis = 3;
    const std::vector<float> anchors{10, 14, 23, 27, 37, 58, 81, 82, 135, 169, 344, 319};

    const auto original = make_shared<opset1::RegionYolo>(
        input, coords, classes, regions, softmax, mask, axis, end_axis, anchors);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::RegionYolo>(builder.create());

    EXPECT_EQ(restored->get_num_coords(), original->get_num_coords());
    EXPECT_EQ(restored->get_num_classes(), original->get_num_classes());
    EXPECT_EQ(restored->get_num_regions(), original->get_num_regions());
    EXPECT_EQ(restored->get_do_softmax(), original->get_do_softmax());
    EXPECT_EQ(restored->get_mask(), original->get_mask());
    EXPECT_EQ(restored->get_anchors(), original->get_anchors());
    EXPECT_EQ(restored->get_axis(), original->get_axis());
    EXPECT_EQ(restored->get_end_axis(), original->get_end_axis());
}
TEST(attributes, reshape_op)
{
    // Reshape's only attribute is special_zero; it must round-trip.
    FactoryRegistry<Node>::get().register_factory<opset1::Reshape>();
    const auto input = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4});
    const auto shape_pattern = make_shared<op::Parameter>(element::i32, Shape{2});
    const bool special_zero = true;

    const auto original = make_shared<opset1::Reshape>(input, shape_pattern, special_zero);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::Reshape>(builder.create());

    EXPECT_EQ(restored->get_special_zero(), original->get_special_zero());
}
TEST(attributes, reverse_op_enum_mode)
{
    // Construct Reverse with the enum overload; mode must round-trip.
    FactoryRegistry<Node>::get().register_factory<opset1::Reverse>();
    const auto input = make_shared<op::Parameter>(element::i32, Shape{200});
    const auto axes = make_shared<op::Parameter>(element::i32, Shape{200});

    const auto original =
        make_shared<opset1::Reverse>(input, axes, opset1::Reverse::Mode::INDEX);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::Reverse>(builder.create());

    EXPECT_EQ(restored->get_mode(), original->get_mode());
}
TEST(attributes, reverse_op_string_mode)
{
    // Construct Reverse with the string overload; mode must round-trip.
    FactoryRegistry<Node>::get().register_factory<opset1::Reverse>();
    const auto input = make_shared<op::Parameter>(element::i32, Shape{200});
    const auto axes = make_shared<op::Parameter>(element::i32, Shape{200});
    const std::string mode = "index";

    const auto original = make_shared<opset1::Reverse>(input, axes, mode);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::Reverse>(builder.create());

    EXPECT_EQ(restored->get_mode(), original->get_mode());
}
TEST(attributes, reverse_sequence_op)
{
    // batch_axis and seq_axis must round-trip through the builder.
    FactoryRegistry<Node>::get().register_factory<opset1::ReverseSequence>();
    const auto input = make_shared<op::Parameter>(element::i32, Shape{2, 3, 4, 2});
    const auto seq_lengths = make_shared<op::Parameter>(element::i32, Shape{4});
    const auto batch_axis = 2;
    const auto seq_axis = 1;

    const auto original =
        make_shared<opset1::ReverseSequence>(input, seq_lengths, batch_axis, seq_axis);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::ReverseSequence>(builder.create());

    EXPECT_EQ(restored->get_origin_batch_axis(), original->get_origin_batch_axis());
    EXPECT_EQ(restored->get_origin_sequence_axis(), original->get_origin_sequence_axis());
}
TEST(attributes, rnn_cell_op_custom_attributes)
{
    // All RNNCellBase attributes, set explicitly, must round-trip.
    FactoryRegistry<Node>::get().register_factory<opset1::RNNCell>();
    const auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    const auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{3, 3});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{3, 3});

    const size_t hidden_size = 3;
    const std::vector<std::string> activations{"sigmoid", "tanh"};
    const std::vector<float> alpha{1.0, 1.5};
    const std::vector<float> beta{2.0, 1.0};
    const float clip = 1.0;

    const auto original = make_shared<opset1::RNNCell>(
        X, H, W, R, hidden_size, activations, alpha, beta, clip);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::RNNCell>(builder.create());

    EXPECT_EQ(restored->get_hidden_size(), original->get_hidden_size());
    EXPECT_EQ(restored->get_clip(), original->get_clip());
    EXPECT_EQ(restored->get_activations(), original->get_activations());
    EXPECT_EQ(restored->get_activations_alpha(), original->get_activations_alpha());
    EXPECT_EQ(restored->get_activations_beta(), original->get_activations_beta());
}
TEST(attributes, rnn_cell_op_default_attributes)
{
    // Defaulted RNNCellBase attributes must also round-trip unchanged.
    FactoryRegistry<Node>::get().register_factory<opset1::RNNCell>();
    const auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    const auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{3, 3});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{3, 3});
    const size_t hidden_size = 3;

    const auto original = make_shared<opset1::RNNCell>(X, H, W, R, hidden_size);
    NodeBuilder builder(original);
    const auto restored = as_type_ptr<opset1::RNNCell>(builder.create());

    EXPECT_EQ(restored->get_hidden_size(), original->get_hidden_size());
    EXPECT_EQ(restored->get_clip(), original->get_clip());
    EXPECT_EQ(restored->get_activations(), original->get_activations());
    EXPECT_EQ(restored->get_activations_alpha(), original->get_activations_alpha());
    EXPECT_EQ(restored->get_activations_beta(), original->get_activations_beta());
}
TEST(attributes, elu_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Elu>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment