Unverified Commit 5d8c39e9 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops R (#4340)

parent 1b611294
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "region_yolo.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
......@@ -43,6 +44,19 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::RegionYolo::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("anchors", m_anchors);
visitor.on_attribute("axis", m_axis);
visitor.on_attribute("coords", m_num_coords);
visitor.on_attribute("classes", m_num_classes);
visitor.on_attribute("end_axis", m_end_axis);
visitor.on_attribute("num", m_num_regions);
visitor.on_attribute("do_softmax", m_do_softmax);
visitor.on_attribute("mask", m_mask);
return true;
}
void op::RegionYolo::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
......
......@@ -55,6 +55,7 @@ namespace ngraph
const int end_axis,
const std::vector<float>& anchors = std::vector<float>{});
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -182,6 +182,11 @@ static PartialShape infer_output_shape(const op::Range* node, const element::Typ
return result;
}
bool ngraph::op::v0::Range::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Range::validate_and_infer_types()
{
set_input_is_relevant_to_shape(0);
......
......@@ -46,6 +46,7 @@ namespace ngraph
const Output<Node>& stop,
const Output<Node>& step);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -65,6 +65,11 @@ op::RNNCell::RNNCell(const Output<Node>& X,
constructor_validate_and_infer_types();
}
bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
{
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::RNNCell::pre_validate_and_infer_types()
{
if (is_dynamic())
......
......@@ -132,6 +132,7 @@ namespace ngraph
const std::vector<float>& activations_beta = {},
float clip = 0.f);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
......
......@@ -156,6 +156,12 @@ op::v1::Reshape::Reshape(const Output<Node>& arg, const Output<Node>& pattern, b
constructor_validate_and_infer_types();
}
bool op::v1::Reshape::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("special_zero", m_special_zero);
return true;
}
void op::v1::Reshape::validate_and_infer_types()
{
auto pattern_et = get_input_element_type(1);
......
......@@ -137,6 +137,7 @@ namespace ngraph
/// from input shape at the same index.
Reshape(const Output<Node>& arg, const Output<Node>& pattern, bool special_zero);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
size_t get_version() const override { return 1; }
......
......@@ -35,6 +35,11 @@ op::Result::Result(const Output<Node>& arg, bool needs_default_layout)
set_placement_index(input_value(0).get_node()->get_placement_index());
}
bool ngraph::op::v0::Result::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::Result::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(
......
......@@ -38,6 +38,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Result(const Output<Node>& arg, bool needs_default_layout = false);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <sstream>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
......@@ -91,6 +92,12 @@ op::v1::Reverse::Reverse(const Output<Node>& data,
constructor_validate_and_infer_types();
}
bool ngraph::op::v1::Reverse::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("mode", m_mode);
return true;
}
void op::v1::Reverse::validate_and_infer_types()
{
if (m_mode == Mode::MASK)
......@@ -194,3 +201,22 @@ op::v1::Reverse::Mode op::v1::Reverse::mode_from_string(const std::string& mode)
return allowed_values.at(mode);
}
namespace ngraph
{
template <>
EnumNames<op::v1::Reverse::Mode>& EnumNames<op::v1::Reverse::Mode>::get()
{
static auto enum_names = EnumNames<op::v1::Reverse::Mode>(
"op::v1::Reverse::Mode",
{{"index", op::v1::Reverse::Mode::INDEX}, {"mask", op::v1::Reverse::Mode::MASK}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v1::Reverse::Mode>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type)
{
return s << as_string(type);
}
} // namespace ngraph
......@@ -108,6 +108,7 @@ namespace ngraph
const Output<Node>& reversed_axes,
const Mode mode);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......@@ -135,4 +136,20 @@ namespace ngraph
// default opset version
using v0::Reverse;
}
std::ostream& operator<<(std::ostream& s, const op::v1::Reverse::Mode& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::Reverse::Mode>
: public EnumAttributeAdapterBase<op::v1::Reverse::Mode>
{
public:
AttributeAdapter(op::v1::Reverse::Mode& value)
: EnumAttributeAdapterBase<op::v1::Reverse::Mode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::v1::Reverse::Mode>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <memory>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/validation_util.hpp"
......@@ -39,6 +40,13 @@ op::ReverseSequence::ReverseSequence(const Output<Node>& arg,
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::ReverseSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("batch_axis", m_batch_axis);
visitor.on_attribute("seq_axis", m_seq_axis);
return true;
}
void op::ReverseSequence::validate_and_infer_types()
{
auto input_shape = get_input_partial_shape(0);
......
......@@ -38,6 +38,7 @@ namespace ngraph
int64_t batch_axis,
int64_t seq_axis);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/arithmetic_reductions_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
......@@ -30,6 +31,12 @@ op::util::ArithmeticReductionKeepDims::ArithmeticReductionKeepDims(
{
}
bool ngraph::op::util::ArithmeticReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::ArithmeticReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
......
......@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes,
bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public:
void validate_and_infer_types() override;
......
......@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/logical_reduction_keep_dims.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp"
......@@ -30,6 +31,12 @@ op::util::LogicalReductionKeepDims::LogicalReductionKeepDims(
{
}
bool ngraph::op::util::LogicalReductionKeepDims::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("keep_dims", m_keep_dims);
return true;
}
void op::util::LogicalReductionKeepDims::validate_and_infer_types()
{
if (m_keep_dims)
......
......@@ -37,6 +37,8 @@ namespace ngraph
const Output<Node>& reduction_axes,
const bool keep_dims = false);
bool visit_attributes(AttributeVisitor& visitor) override;
public:
void validate_and_infer_types() override;
......
......@@ -17,6 +17,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/fused/clamp.hpp"
#include "ngraph/op/multiply.hpp"
......@@ -48,6 +49,16 @@ op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
{
}
bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
return true;
}
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
......
......@@ -58,6 +58,7 @@ namespace ngraph
RNNCellBase() = default;
virtual bool visit_attributes(AttributeVisitor& visitor);
std::size_t get_hidden_size() const { return m_hidden_size; }
float get_clip() const { return m_clip; }
const std::vector<std::string>& get_activations() const { return m_activations; }
......
......@@ -43,3 +43,8 @@ void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic();
}
bool op::util::UnaryElementwiseArithmetic::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
......@@ -68,6 +68,7 @@ namespace ngraph
public:
void validate_and_infer_types() override;
bool is_unary_elementwise_arithmetic() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
};
}
}
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment