Unverified Commit 61b4e6f9 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops E-L (#4346)

* Add new attribute adapters for vector<float> and vector<string> types

* Add attribute visitor to LSTMSequence op

* Add attribute visitor to LSTMCell op

* Add attribute visitor to LRN op

* Add attribute visitor to LogicalNot op

* Add attribute visitor to Log op

* Add attribute visitor to HardSigmoid op

* Add attribute visitor to GRN op

* Add attribute visitor to GroupConvolutionBackpropData op

* Add attribute visitor to GroupConvolution op

* Add attribute visitor to GatherTree op

* Add attribute visitor to Gather op

* Add attribute visitor to Floor op

* Add attribute visitor to FakeQuantize op

* Add attribute visitor to Exp op

* Add attribute visitor to Erf op

* Add attribute visitor to Elu op

* Add test for LRN, LSTMSewuence and LSTMCell

* Add test for Elu

* Add test for FakeQuantize

* Revert user_op test missed

* Add test for GRN

* Fix for CoordinateDiff

* Add tests for GroupConvolution and GroupCinvolutionBackpropData

* Tests alphabetical reorder
Co-authored-by: 's avatarKatarzyna Mitrus <katarzyna.mitrus@intel.com>
parent 5a5579f7
...@@ -242,4 +242,40 @@ namespace ngraph ...@@ -242,4 +242,40 @@ namespace ngraph
m_value = copy_from<vector<uint64_t>>(value); m_value = copy_from<vector<uint64_t>>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
constexpr DiscreteTypeInfo AttributeAdapter<vector<float>>::type_info;
const vector<float>& AttributeAdapter<vector<float>>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<float>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<vector<float>>::set(const vector<float>& value)
{
m_value = copy_from<vector<float>>(value);
m_buffer_valid = false;
}
constexpr DiscreteTypeInfo AttributeAdapter<vector<string>>::type_info;
const vector<string>& AttributeAdapter<vector<string>>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<string>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<vector<string>>::set(const vector<string>& value)
{
m_value = copy_from<vector<string>>(value);
m_buffer_valid = false;
}
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
...@@ -299,6 +300,39 @@ namespace ngraph ...@@ -299,6 +300,39 @@ namespace ngraph
void set(const std::vector<int64_t>& value) override; void set(const std::vector<int64_t>& value) override;
}; };
template <>
class NGRAPH_API AttributeAdapter<std::vector<float>>
: public ValueReference<std::vector<float>>, public ValueAccessor<std::vector<float>>
{
public:
AttributeAdapter(std::vector<float>& value)
: ValueReference<std::vector<float>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<float>& get() override;
void set(const std::vector<float>& value) override;
};
template <>
class NGRAPH_API AttributeAdapter<std::vector<std::string>>
: public ValueReference<std::vector<std::string>>,
public ValueAccessor<std::vector<std::string>>
{
public:
AttributeAdapter(std::vector<std::string>& value)
: ValueReference<std::vector<std::string>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<std::string>& get() override;
void set(const std::vector<std::string>& value) override;
};
template <typename A, typename B> template <typename A, typename B>
A copy_from(B& b) A copy_from(B& b)
{ {
......
...@@ -46,16 +46,25 @@ namespace ngraph ...@@ -46,16 +46,25 @@ namespace ngraph
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}; };
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter) ValueAccessor<std::vector<int64_t>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter) virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
...@@ -68,5 +77,11 @@ namespace ngraph ...@@ -68,5 +77,11 @@ namespace ngraph
AttributeAdapter<T> adapter(value); AttributeAdapter<T> adapter(value);
on_adapter(name, adapter); on_adapter(name, adapter);
} }
void on_attribute(const std::string& name, op::AutoBroadcastSpec& value)
{
AttributeAdapter<op::AutoBroadcastType> adapter(value.m_type);
on_adapter(name, adapter);
}
}; };
} }
...@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get() ...@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value) void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value)
{ {
m_value = copy_from<CoordinateDiff>(m_value); m_value = copy_from<CoordinateDiff>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
......
...@@ -23,6 +23,11 @@ using namespace ngraph; ...@@ -23,6 +23,11 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Erf::type_info; constexpr NodeTypeInfo op::Erf::type_info;
bool ngraph::op::v0::Erf::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -32,6 +32,7 @@ namespace ngraph ...@@ -32,6 +32,7 @@ namespace ngraph
Erf() = default; Erf() = default;
Erf(const Output<Node>& arg); Erf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Exp::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Exp(const Output<Node>& arg); Exp(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg) ...@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Floor::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Floor(const Output<Node>& arg); Floor(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha) ...@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Elu::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("alpha", m_alpha);
return true;
}
NodeVector op::Elu::decompose_op() const NodeVector op::Elu::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
/// \param alpha Multiplier for negative values /// \param alpha Multiplier for negative values
Elu(const Output<Node>& data, const double alpha); Elu(const Output<Node>& data, const double alpha);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "fake_quantize.hpp" #include "fake_quantize.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types() ...@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
bool ngraph::op::v0::FakeQuantize::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("levels", m_levels);
visitor.on_attribute("auto_broadcast", m_auto_broadcast);
return true;
}
NodeVector op::FakeQuantize::decompose_op() const NodeVector op::FakeQuantize::decompose_op() const
{ {
Output<Node> data{input_value(0)}; Output<Node> data{input_value(0)};
......
...@@ -67,6 +67,7 @@ namespace ngraph ...@@ -67,6 +67,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void validate_and_infer_types() override;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <iterator> #include <iterator>
#include "grn.hpp" #include "grn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
...@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias) ...@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("bias", m_bias);
return true;
}
void op::GRN::pre_validate_and_infer_types() void op::GRN::pre_validate_and_infer_types()
{ {
const auto& data_pshape = get_input_partial_shape(0); const auto& data_pshape = get_input_partial_shape(0);
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
/// ///
GRN(const Output<Node>& data, float bias); GRN(const Output<Node>& data, float bias);
bool visit_attributes(AttributeVisitor& visitor) override;
float get_bias() const { return m_bias; } float get_bias() const { return m_bias; }
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "group_conv.hpp" #include "group_conv.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
...@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch, ...@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::GroupConvolution::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("auto_pad", m_auto_pad);
return true;
}
void op::v1::GroupConvolution::validate_and_infer_types() void op::v1::GroupConvolution::validate_and_infer_types()
{ {
const PartialShape& data_batch_pshape = get_input_partial_shape(0); const PartialShape& data_batch_pshape = get_input_partial_shape(0);
...@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData( ...@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::GroupConvolutionBackpropData::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("auto_pad", m_auto_pad);
visitor.on_attribute("output_padding", m_output_padding);
return true;
}
bool op::v1::GroupConvolutionBackpropData::is_dynamic() const bool op::v1::GroupConvolutionBackpropData::is_dynamic() const
{ {
bool is_dynamic = Node::is_dynamic(); bool is_dynamic = Node::is_dynamic();
......
...@@ -61,6 +61,8 @@ namespace ngraph ...@@ -61,6 +61,8 @@ namespace ngraph
const CoordinateDiff& pads_end, const CoordinateDiff& pads_end,
const Strides& dilations, const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT); const PadType& auto_pad = PadType::EXPLICIT);
bool visit_attributes(AttributeVisitor& visitor) override;
// TODO - Remove supports_decompose and validate_and_infer_type once op supports // TODO - Remove supports_decompose and validate_and_infer_type once op supports
// decomposition // decomposition
bool supports_decompose() const override { return false; } bool supports_decompose() const override { return false; }
...@@ -187,6 +189,7 @@ namespace ngraph ...@@ -187,6 +189,7 @@ namespace ngraph
const PadType& auto_pad = PadType::EXPLICIT, const PadType& auto_pad = PadType::EXPLICIT,
const CoordinateDiff& output_padding = {}); const CoordinateDiff& output_padding = {});
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_dynamic() const override; virtual bool is_dynamic() const override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
......
...@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data, ...@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::HardSigmoid::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::HardSigmoid::pre_validate_and_infer_types() void op::HardSigmoid::pre_validate_and_infer_types()
{ {
const auto& alpha_pshape = get_input_partial_shape(1); const auto& alpha_pshape = get_input_partial_shape(1);
......
...@@ -46,6 +46,7 @@ namespace ngraph ...@@ -46,6 +46,7 @@ namespace ngraph
const Output<Node>& alpha, const Output<Node>& alpha,
const Output<Node>& beta); const Output<Node>& beta);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cmath> #include <cmath>
#include <functional> #include <functional>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X, ...@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
visitor.on_attribute("input_forget", m_input_forget);
visitor.on_attribute("weights_format", m_weights_format);
return true;
}
void op::LSTMCell::pre_validate_and_infer_types() void op::LSTMCell::pre_validate_and_infer_types()
{ {
set_output_size(2); set_output_size(2);
...@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co ...@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
} }
namespace ngraph
{
template <>
EnumNames<op::LSTMWeightsFormat>& EnumNames<op::LSTMWeightsFormat>::get()
{
static auto enum_names =
EnumNames<op::LSTMWeightsFormat>("op::LSTMWeightsFormat",
{{"fico", op::LSTMWeightsFormat::FICO},
{"icof", op::LSTMWeightsFormat::ICOF},
{"ifco", op::LSTMWeightsFormat::IFCO},
{"ifoc", op::LSTMWeightsFormat::IFOC},
{"iofc", op::LSTMWeightsFormat::IOFC}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::LSTMWeightsFormat>::type_info;
std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -224,6 +224,7 @@ namespace ngraph ...@@ -224,6 +224,7 @@ namespace ngraph
float clip = 0.f, float clip = 0.f,
bool input_forget = false); bool input_forget = false);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -284,4 +285,20 @@ namespace ngraph ...@@ -284,4 +285,20 @@ namespace ngraph
} }
using v0::LSTMCell; using v0::LSTMCell;
} // namespace op } // namespace op
std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
template <>
class NGRAPH_API AttributeAdapter<op::LSTMWeightsFormat>
: public EnumAttributeAdapterBase<op::LSTMWeightsFormat>
{
public:
AttributeAdapter(op::LSTMWeightsFormat& value)
: EnumAttributeAdapterBase<op::LSTMWeightsFormat>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::LSTMWeightsFormat>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph } // namespace ngraph
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/fused/lstm_sequence.hpp" #include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
...@@ -32,6 +33,19 @@ using namespace ngraph; ...@@ -32,6 +33,19 @@ using namespace ngraph;
using namespace std; using namespace std;
constexpr NodeTypeInfo op::LSTMSequence::type_info; constexpr NodeTypeInfo op::LSTMSequence::type_info;
bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip_threshold);
visitor.on_attribute("direction", m_direction);
visitor.on_attribute("input_forget", m_input_forget);
visitor.on_attribute("weights_format", m_weights_format);
return true;
}
NodeVector op::LSTMSequence::decompose_op() const NodeVector op::LSTMSequence::decompose_op() const
{ {
NodeVector results; NodeVector results;
...@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve ...@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs. // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return builder::squeeze(tmp); return builder::squeeze(tmp);
} }
namespace ngraph
{
template <>
EnumNames<op::v0::LSTMSequence::direction>& EnumNames<op::v0::LSTMSequence::direction>::get()
{
static auto enum_names = EnumNames<op::v0::LSTMSequence::direction>(
"op::v0::LSTMSequence::direction",
{{"forward", op::v0::LSTMSequence::direction::FORWARD},
{"reverse", op::v0::LSTMSequence::direction::REVERSE},
{"bidirectional", op::v0::LSTMSequence::direction::BIDIRECTIONAL}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v0::LSTMSequence::direction>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -135,6 +135,7 @@ namespace ngraph ...@@ -135,6 +135,7 @@ namespace ngraph
{ {
} }
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -185,4 +186,21 @@ namespace ngraph ...@@ -185,4 +186,21 @@ namespace ngraph
} }
using v0::LSTMSequence; using v0::LSTMSequence;
} // namespace op } // namespace op
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::LSTMSequence::direction>
: public EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>
{
public:
AttributeAdapter(op::v0::LSTMSequence::direction& value)
: EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::LSTMSequence::direction>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph } // namespace ngraph
...@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params, ...@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::Gather::validate_and_infer_types() void op::v1::Gather::validate_and_infer_types()
{ {
const auto& input_rank = get_input_partial_shape(PARAMS).rank(); const auto& input_rank = get_input_partial_shape(PARAMS).rank();
......
...@@ -67,6 +67,7 @@ namespace ngraph ...@@ -67,6 +67,7 @@ namespace ngraph
const Output<Node>& indices, const Output<Node>& indices,
const Output<Node>& axis); const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t get_axis() const; int64_t get_axis() const;
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar ...@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3)); new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
} }
bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::GatherTree::validate_and_infer_types() void op::v1::GatherTree::validate_and_infer_types()
{ {
const auto& step_ids_rank = get_input_partial_shape(0); const auto& step_ids_rank = get_input_partial_shape(0);
......
...@@ -44,6 +44,7 @@ namespace ngraph ...@@ -44,6 +44,7 @@ namespace ngraph
const Output<Node>& max_seq_len, const Output<Node>& max_seq_len,
const Output<Node>& end_token); const Output<Node>& end_token);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Log::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Log(const Output<Node>& arg); Log(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/lrn.hpp" #include "ngraph/op/lrn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
...@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types() ...@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
")."); ").");
} }
bool ngraph::op::v0::LRN::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("alpha", m_alpha);
visitor.on_attribute("beta", m_beta);
visitor.on_attribute("bias", m_bias);
visitor.on_attribute("size", m_size);
return true;
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -58,6 +58,7 @@ namespace ngraph ...@@ -58,6 +58,7 @@ namespace ngraph
double bias, double bias,
size_t size); size_t size);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops. // TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v1::LogicalNot::validate_and_infer_types() void op::v1::LogicalNot::validate_and_infer_types()
{ {
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
LogicalNot(const Output<Node>& arg); LogicalNot(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment