Unverified Commit 5d8c39e9 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops R (#4340)

parent 1b611294
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "region_yolo.hpp" #include "region_yolo.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input, ...@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("anchors", m_anchors);
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("coords", m_num_coords);
visitor.on_attribute("classes", m_num_classes);
visitor.on_attribute("end_axis", m_end_axis);
visitor.on_attribute("num", m_num_regions);
visitor.on_attribute("do_softmax", m_do_softmax);
visitor.on_attribute("mask", m_mask);
return true;
}
void op::RegionYolo::validate_and_infer_types() void op::RegionYolo::validate_and_infer_types()
{ {
auto input_et = get_input_element_type(0); auto input_et = get_input_element_type(0);
......
...@@ -55,6 +55,7 @@ namespace ngraph ...@@ -55,6 +55,7 @@ namespace ngraph
const int end_axis, const int end_axis,
const std::vector<float>& anchors = std::vector<float>{}); const std::vector<float>& anchors = std::vector<float>{});
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ ...@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ
return result; return result;
} }
bool ngraph::op::v0::Range::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Range::validate_and_infer_types() void op::Range::validate_and_infer_types()
{ {
set_input_is_relevant_to_shape(0); set_input_is_relevant_to_shape(0);
......
...@@ -46,6 +46,7 @@ namespace ngraph ...@@ -46,6 +46,7 @@ namespace ngraph
const Output<Node>& stop, const Output<Node>& stop,
const Output<Node>& step); const Output<Node>& step);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X, ...@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
{
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::RNNCell::pre_validate_and_infer_types() void op::RNNCell::pre_validate_and_infer_types()
{ {
if (is_dynamic()) if (is_dynamic())
......
...@@ -132,6 +132,7 @@ namespace ngraph ...@@ -132,6 +132,7 @@ namespace ngraph
const std::vector<float>& activations_beta = {}, const std::vector<float>& activations_beta = {},
float clip = 0.f); float clip = 0.f);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b ...@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("special_zero", m_special_zero);
return true;
}
void op::v1::Reshape::validate_and_infer_types() void op::v1::Reshape::validate_and_infer_types()
{ {
auto pattern_et = get_input_element_type(1); auto pattern_et = get_input_element_type(1);
......
...@@ -137,6 +137,7 @@ namespace ngraph ...@@ -137,6 +137,7 @@ namespace ngraph
/// from input shape at the same index. /// from input shape at the same index.
Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero); Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
size_t get_version() const override { return 1; } size_t get_version() const override { return 1; }
......
...@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout) ...@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
set_placement_index(input_value(0).get_node()->get_placement_index()); set_placement_index(input_value(0).get_node()->get_placement_index());
} }
bool ngraph::op::v0::Result::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Result::validate_and_infer_types() void op::Result::validate_and_infer_types()
{ {
NODE_VALIDATION_CHECK( NODE_VALIDATION_CHECK(
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false); Result(const Output<Node>& arg, bool needs_default_layout = false);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <sstream> #include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp" #include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
...@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data, ...@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Reverse::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("mode", m_mode);
return true;
}
void op::v1::Reverse::validate_and_infer_types() void op::v1::Reverse::validate_and_infer_types()
{ {
if (m_mode == Mode::MASK) if (m_mode == Mode::MASK)
...@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode) ...@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode)
return allowed_values.at(mode); return allowed_values.at(mode);
} }
namespace ngraph
{
template <>
EnumNames<op::v1::Reverse::Mode>& EnumNames<op::v1::Reverse::Mode>::get()
{
static auto enum_names = EnumNames<op::v1::Reverse::Mode>(
"op::v1::Reverse::Mode",
{{"index", op::v1::Reverse::Mode::INDEX}, {"mask", op::v1::Reverse::Mode::MASK}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v1::Reverse::Mode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -108,6 +108,7 @@ namespace ngraph ...@@ -108,6 +108,7 @@ namespace ngraph
const Output<Node>& reversed_axes, const Output<Node>& reversed_axes,
const Mode mode); const Mode mode);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -135,4 +136,20 @@ namespace ngraph ...@@ -135,4 +136,20 @@ namespace ngraph
// default opset version // default opset version
using v0::Reverse; using v0::Reverse;
} }
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::Reverse::Mode>
: public EnumAttributeAdapterBase<op::v1::Reverse::Mode>
{
public:
AttributeAdapter(op::v1::Reverse::Mode& value)
: EnumAttributeAdapterBase<op::v1::Reverse::Mode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::Reverse::Mode>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg, ...@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::ReverseSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("batch_axis", m_batch_axis);
visitor.on_attribute("seq_axis", m_seq_axis);
return true;
}
void op::ReverseSequence::validate_and_infer_types() void op::ReverseSequence::validate_and_infer_types()
{ {
auto input_shape = get_input_partial_shape(0); auto input_shape = get_input_partial_shape(0);
......
...@@ -38,6 +38,7 @@ namespace ngraph ...@@ -38,6 +38,7 @@ namespace ngraph
int64_t batch_axis, int64_t batch_axis,
int64_t seq_axis); int64_t seq_axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp" #include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims( ...@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
{ {
} }
bool ngraph::op::util::ArithmeticReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::ArithmeticReductionKeepDims::validate_and_infer_types() void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{ {
if (m_keep_dims) if (m_keep_dims)
......
...@@ -37,6 +37,8 @@ namespace ngraph ...@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
bool keep_dims = false); bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/util/logical_reduction_keep_dims.hpp" #include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims( ...@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
{ {
} }
bool ngraph::op::util::LogicalReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::LogicalReductionKeepDims::validate_and_infer_types() void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{ {
if (m_keep_dims) if (m_keep_dims)
......
...@@ -37,6 +37,8 @@ namespace ngraph ...@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes, const Output<Node>& reduction_axes,
const bool keep_dims = false); const bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/fused/clamp.hpp" #include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
...@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size, ...@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
{ {
} }
bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
return true;
}
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{ {
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx)); op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
......
...@@ -58,6 +58,7 @@ namespace ngraph ...@@ -58,6 +58,7 @@ namespace ngraph
RNNCellBase() = default; RNNCellBase() = default;
virtual bool visit_attributes(AttributeVisitor& visitor);
std::size_t get_hidden_size() const { return m_hidden_size; } std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; } float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; } const std::vector<std::string>& get_activations() const { return m_activations; }
......
...@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types() ...@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{ {
validate_and_infer_elementwise_arithmetic(); validate_and_infer_elementwise_arithmetic();
} }
bool op::util::UnaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
...@@ -68,6 +68,7 @@ namespace ngraph ...@@ -68,6 +68,7 @@ namespace ngraph
public: public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
bool is_unary_elementwise_arithmetic() const override { return true; } bool is_unary_elementwise_arithmetic() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
}; };
} }
} }
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment