Unverified Commit 61b4e6f9 authored by Tomasz Socha's avatar Tomasz Socha Committed by GitHub

Add attribute visitor for ops E-L (#4346)

* Add new attribute adapters for vector<float> and vector<string> types

* Add attribute visitor to LSTMSequence op

* Add attribute visitor to LSTMCell op

* Add attribute visitor to LRN op

* Add attribute visitor to LogicalNot op

* Add attribute visitor to Log op

* Add attribute visitor to HardSigmoid op

* Add attribute visitor to GRN op

* Add attribute visitor to GroupConvolutionBackpropData op

* Add attribute visitor to GroupConvolution op

* Add attribute visitor to GatherTree op

* Add attribute visitor to Gather op

* Add attribute visitor to Floor op

* Add attribute visitor to FakeQuantize op

* Add attribute visitor to Exp op

* Add attribute visitor to Erf op

* Add attribute visitor to Elu op

* Add test for LRN, LSTMSewuence and LSTMCell

* Add test for Elu

* Add test for FakeQuantize

* Revert user_op test missed

* Add test for GRN

* Fix for CoordinateDiff

* Add tests for GroupConvolution and GroupCinvolutionBackpropData

* Tests alphabetical reorder
Co-authored-by: 's avatarKatarzyna Mitrus <katarzyna.mitrus@intel.com>
parent 5a5579f7
...@@ -242,4 +242,40 @@ namespace ngraph ...@@ -242,4 +242,40 @@ namespace ngraph
m_value = copy_from<vector<uint64_t>>(value); m_value = copy_from<vector<uint64_t>>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
constexpr DiscreteTypeInfo AttributeAdapter<vector<float>>::type_info;
const vector<float>& AttributeAdapter<vector<float>>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<float>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<vector<float>>::set(const vector<float>& value)
{
m_value = copy_from<vector<float>>(value);
m_buffer_valid = false;
}
constexpr DiscreteTypeInfo AttributeAdapter<vector<string>>::type_info;
const vector<string>& AttributeAdapter<vector<string>>::get()
{
if (!m_buffer_valid)
{
m_buffer = copy_from<vector<string>>(m_value);
m_buffer_valid = true;
}
return m_buffer;
}
void AttributeAdapter<vector<string>>::set(const vector<string>& value)
{
m_value = copy_from<vector<string>>(value);
m_buffer_valid = false;
}
} }
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#pragma once #pragma once
#include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
...@@ -299,6 +300,39 @@ namespace ngraph ...@@ -299,6 +300,39 @@ namespace ngraph
void set(const std::vector<int64_t>& value) override; void set(const std::vector<int64_t>& value) override;
}; };
template <>
class NGRAPH_API AttributeAdapter<std::vector<float>>
: public ValueReference<std::vector<float>>, public ValueAccessor<std::vector<float>>
{
public:
AttributeAdapter(std::vector<float>& value)
: ValueReference<std::vector<float>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<float>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<float>& get() override;
void set(const std::vector<float>& value) override;
};
template <>
class NGRAPH_API AttributeAdapter<std::vector<std::string>>
: public ValueReference<std::vector<std::string>>,
public ValueAccessor<std::vector<std::string>>
{
public:
AttributeAdapter(std::vector<std::string>& value)
: ValueReference<std::vector<std::string>>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<vector<string>>", 0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
const std::vector<std::string>& get() override;
void set(const std::vector<std::string>& value) override;
};
template <typename A, typename B> template <typename A, typename B>
A copy_from(B& b) A copy_from(B& b)
{ {
......
...@@ -46,16 +46,25 @@ namespace ngraph ...@@ -46,16 +46,25 @@ namespace ngraph
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}; };
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter)
{
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
virtual void on_adapter(const std::string& name, virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter) ValueAccessor<std::vector<int64_t>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
virtual void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) virtual void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
virtual void on_adapter(const std::string& name, ValueAccessor<double>& adapter) virtual void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter)
{ {
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter)); on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
} }
...@@ -68,5 +77,11 @@ namespace ngraph ...@@ -68,5 +77,11 @@ namespace ngraph
AttributeAdapter<T> adapter(value); AttributeAdapter<T> adapter(value);
on_adapter(name, adapter); on_adapter(name, adapter);
} }
void on_attribute(const std::string& name, op::AutoBroadcastSpec& value)
{
AttributeAdapter<op::AutoBroadcastType> adapter(value.m_type);
on_adapter(name, adapter);
}
}; };
} }
...@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get() ...@@ -80,7 +80,7 @@ const vector<int64_t>& AttributeAdapter<CoordinateDiff>::get()
void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value) void AttributeAdapter<CoordinateDiff>::set(const vector<int64_t>& value)
{ {
m_value = copy_from<CoordinateDiff>(m_value); m_value = copy_from<CoordinateDiff>(value);
m_buffer_valid = false; m_buffer_valid = false;
} }
......
...@@ -23,6 +23,11 @@ using namespace ngraph; ...@@ -23,6 +23,11 @@ using namespace ngraph;
constexpr NodeTypeInfo op::Erf::type_info; constexpr NodeTypeInfo op::Erf::type_info;
bool ngraph::op::v0::Erf::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Erf::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -32,6 +32,7 @@ namespace ngraph ...@@ -32,6 +32,7 @@ namespace ngraph
Erf() = default; Erf() = default;
Erf(const Output<Node>& arg); Erf(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::Exp::Exp(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Exp::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Exp::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Exp(const Output<Node>& arg); Exp(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg) ...@@ -27,6 +27,11 @@ op::Floor::Floor(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Floor::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Floor::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Floor(const Output<Node>& arg); Floor(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/fused/elu.hpp" #include "ngraph/op/fused/elu.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha) ...@@ -37,6 +38,12 @@ op::Elu::Elu(const Output<Node>& data, const double alpha)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Elu::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("alpha", m_alpha);
return true;
}
NodeVector op::Elu::decompose_op() const NodeVector op::Elu::decompose_op() const
{ {
auto data = input_value(0); auto data = input_value(0);
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
/// \param alpha Multiplier for negative values /// \param alpha Multiplier for negative values
Elu(const Output<Node>& data, const double alpha); Elu(const Output<Node>& data, const double alpha);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <memory> #include <memory>
#include "fake_quantize.hpp" #include "fake_quantize.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types() ...@@ -80,6 +81,13 @@ void op::FakeQuantize::validate_and_infer_types()
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
} }
bool ngraph::op::v0::FakeQuantize::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("levels", m_levels);
visitor.on_attribute("auto_broadcast", m_auto_broadcast);
return true;
}
NodeVector op::FakeQuantize::decompose_op() const NodeVector op::FakeQuantize::decompose_op() const
{ {
Output<Node> data{input_value(0)}; Output<Node> data{input_value(0)};
......
...@@ -67,6 +67,7 @@ namespace ngraph ...@@ -67,6 +67,7 @@ namespace ngraph
const AutoBroadcastSpec& auto_broadcast = const AutoBroadcastSpec& auto_broadcast =
AutoBroadcastSpec(AutoBroadcastType::NUMPY)); AutoBroadcastSpec(AutoBroadcastType::NUMPY));
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override; virtual void validate_and_infer_types() override;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <iterator> #include <iterator>
#include "grn.hpp" #include "grn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/axis_set.hpp" #include "ngraph/axis_set.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
...@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias) ...@@ -36,6 +37,12 @@ op::GRN::GRN(const Output<Node>& data, float bias)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::GRN::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("bias", m_bias);
return true;
}
void op::GRN::pre_validate_and_infer_types() void op::GRN::pre_validate_and_infer_types()
{ {
const auto& data_pshape = get_input_partial_shape(0); const auto& data_pshape = get_input_partial_shape(0);
......
...@@ -42,6 +42,7 @@ namespace ngraph ...@@ -42,6 +42,7 @@ namespace ngraph
/// ///
GRN(const Output<Node>& data, float bias); GRN(const Output<Node>& data, float bias);
bool visit_attributes(AttributeVisitor& visitor) override;
float get_bias() const { return m_bias; } float get_bias() const { return m_bias; }
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include "group_conv.hpp" #include "group_conv.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/concat.hpp" #include "ngraph/op/concat.hpp"
...@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch, ...@@ -57,6 +58,16 @@ op::v1::GroupConvolution::GroupConvolution(const Output<Node>& data_batch,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::GroupConvolution::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("auto_pad", m_auto_pad);
return true;
}
void op::v1::GroupConvolution::validate_and_infer_types() void op::v1::GroupConvolution::validate_and_infer_types()
{ {
const PartialShape& data_batch_pshape = get_input_partial_shape(0); const PartialShape& data_batch_pshape = get_input_partial_shape(0);
...@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData( ...@@ -219,6 +230,17 @@ op::v1::GroupConvolutionBackpropData::GroupConvolutionBackpropData(
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::GroupConvolutionBackpropData::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("dilations", m_dilations);
visitor.on_attribute("auto_pad", m_auto_pad);
visitor.on_attribute("output_padding", m_output_padding);
return true;
}
bool op::v1::GroupConvolutionBackpropData::is_dynamic() const bool op::v1::GroupConvolutionBackpropData::is_dynamic() const
{ {
bool is_dynamic = Node::is_dynamic(); bool is_dynamic = Node::is_dynamic();
......
...@@ -61,6 +61,8 @@ namespace ngraph ...@@ -61,6 +61,8 @@ namespace ngraph
const CoordinateDiff& pads_end, const CoordinateDiff& pads_end,
const Strides& dilations, const Strides& dilations,
const PadType& auto_pad = PadType::EXPLICIT); const PadType& auto_pad = PadType::EXPLICIT);
bool visit_attributes(AttributeVisitor& visitor) override;
// TODO - Remove supports_decompose and validate_and_infer_type once op supports // TODO - Remove supports_decompose and validate_and_infer_type once op supports
// decomposition // decomposition
bool supports_decompose() const override { return false; } bool supports_decompose() const override { return false; }
...@@ -187,6 +189,7 @@ namespace ngraph ...@@ -187,6 +189,7 @@ namespace ngraph
const PadType& auto_pad = PadType::EXPLICIT, const PadType& auto_pad = PadType::EXPLICIT,
const CoordinateDiff& output_padding = {}); const CoordinateDiff& output_padding = {});
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_dynamic() const override; virtual bool is_dynamic() const override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
......
...@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data, ...@@ -37,6 +37,11 @@ op::HardSigmoid::HardSigmoid(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::HardSigmoid::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::HardSigmoid::pre_validate_and_infer_types() void op::HardSigmoid::pre_validate_and_infer_types()
{ {
const auto& alpha_pshape = get_input_partial_shape(1); const auto& alpha_pshape = get_input_partial_shape(1);
......
...@@ -46,6 +46,7 @@ namespace ngraph ...@@ -46,6 +46,7 @@ namespace ngraph
const Output<Node>& alpha, const Output<Node>& alpha,
const Output<Node>& beta); const Output<Node>& beta);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <cmath> #include <cmath>
#include <functional> #include <functional>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
...@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X, ...@@ -107,6 +108,19 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
visitor.on_attribute("input_forget", m_input_forget);
visitor.on_attribute("weights_format", m_weights_format);
return true;
}
void op::LSTMCell::pre_validate_and_infer_types() void op::LSTMCell::pre_validate_and_infer_types()
{ {
set_output_size(2); set_output_size(2);
...@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co ...@@ -386,3 +400,26 @@ shared_ptr<Node> op::LSTMCell::copy_with_new_args(const NodeVector& new_args) co
throw ngraph_error("Incorrect number of new arguments"); throw ngraph_error("Incorrect number of new arguments");
} }
} }
namespace ngraph
{
template <>
EnumNames<op::LSTMWeightsFormat>& EnumNames<op::LSTMWeightsFormat>::get()
{
static auto enum_names =
EnumNames<op::LSTMWeightsFormat>("op::LSTMWeightsFormat",
{{"fico", op::LSTMWeightsFormat::FICO},
{"icof", op::LSTMWeightsFormat::ICOF},
{"ifco", op::LSTMWeightsFormat::IFCO},
{"ifoc", op::LSTMWeightsFormat::IFOC},
{"iofc", op::LSTMWeightsFormat::IOFC}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::LSTMWeightsFormat>::type_info;
std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -224,6 +224,7 @@ namespace ngraph ...@@ -224,6 +224,7 @@ namespace ngraph
float clip = 0.f, float clip = 0.f,
bool input_forget = false); bool input_forget = false);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override; virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -284,4 +285,20 @@ namespace ngraph ...@@ -284,4 +285,20 @@ namespace ngraph
} }
using v0::LSTMCell; using v0::LSTMCell;
} // namespace op } // namespace op
std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
template <>
class NGRAPH_API AttributeAdapter<op::LSTMWeightsFormat>
: public EnumAttributeAdapterBase<op::LSTMWeightsFormat>
{
public:
AttributeAdapter(op::LSTMWeightsFormat& value)
: EnumAttributeAdapterBase<op::LSTMWeightsFormat>(value)
{
}
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<op::LSTMWeightsFormat>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph } // namespace ngraph
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "ngraph/op/fused/lstm_sequence.hpp" #include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/autobroadcast.hpp" #include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp" #include "ngraph/builder/split.hpp"
...@@ -32,6 +33,19 @@ using namespace ngraph; ...@@ -32,6 +33,19 @@ using namespace ngraph;
using namespace std; using namespace std;
constexpr NodeTypeInfo op::LSTMSequence::type_info; constexpr NodeTypeInfo op::LSTMSequence::type_info;
bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip_threshold);
visitor.on_attribute("direction", m_direction);
visitor.on_attribute("input_forget", m_input_forget);
visitor.on_attribute("weights_format", m_weights_format);
return true;
}
NodeVector op::LSTMSequence::decompose_op() const NodeVector op::LSTMSequence::decompose_op() const
{ {
NodeVector results; NodeVector results;
...@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve ...@@ -247,3 +261,24 @@ shared_ptr<Node> op::LSTMSequence::prepare_input(Output<Node> node, bool is_reve
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs. // Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return builder::squeeze(tmp); return builder::squeeze(tmp);
} }
namespace ngraph
{
template <>
EnumNames<op::v0::LSTMSequence::direction>& EnumNames<op::v0::LSTMSequence::direction>::get()
{
static auto enum_names = EnumNames<op::v0::LSTMSequence::direction>(
"op::v0::LSTMSequence::direction",
{{"forward", op::v0::LSTMSequence::direction::FORWARD},
{"reverse", op::v0::LSTMSequence::direction::REVERSE},
{"bidirectional", op::v0::LSTMSequence::direction::BIDIRECTIONAL}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v0::LSTMSequence::direction>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type)
{
return s << as_string(type);
}
} // namespace ngraph
...@@ -135,6 +135,7 @@ namespace ngraph ...@@ -135,6 +135,7 @@ namespace ngraph
{ {
} }
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override; virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -185,4 +186,21 @@ namespace ngraph ...@@ -185,4 +186,21 @@ namespace ngraph
} }
using v0::LSTMSequence; using v0::LSTMSequence;
} // namespace op } // namespace op
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::LSTMSequence::direction>
: public EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>
{
public:
AttributeAdapter(op::v0::LSTMSequence::direction& value)
: EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::LSTMSequence::direction>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph } // namespace ngraph
...@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params, ...@@ -112,6 +112,11 @@ op::v1::Gather::Gather(const Output<Node>& params,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::Gather::validate_and_infer_types() void op::v1::Gather::validate_and_infer_types()
{ {
const auto& input_rank = get_input_partial_shape(PARAMS).rank(); const auto& input_rank = get_input_partial_shape(PARAMS).rank();
......
...@@ -67,6 +67,7 @@ namespace ngraph ...@@ -67,6 +67,7 @@ namespace ngraph
const Output<Node>& indices, const Output<Node>& indices,
const Output<Node>& axis); const Output<Node>& axis);
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t get_axis() const; int64_t get_axis() const;
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar ...@@ -38,6 +38,11 @@ shared_ptr<Node> op::v1::GatherTree::copy_with_new_args(const NodeVector& new_ar
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3)); new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
} }
bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
void op::v1::GatherTree::validate_and_infer_types() void op::v1::GatherTree::validate_and_infer_types()
{ {
const auto& step_ids_rank = get_input_partial_shape(0); const auto& step_ids_rank = get_input_partial_shape(0);
......
...@@ -44,6 +44,7 @@ namespace ngraph ...@@ -44,6 +44,7 @@ namespace ngraph
const Output<Node>& max_seq_len, const Output<Node>& max_seq_len,
const Output<Node>& end_token); const Output<Node>& end_token);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::Log::Log(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::Log::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::Log::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
Log(const Output<Node>& arg); Log(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
//***************************************************************************** //*****************************************************************************
#include "ngraph/op/lrn.hpp" #include "ngraph/op/lrn.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
...@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types() ...@@ -111,6 +112,15 @@ void op::LRN::validate_and_infer_types()
")."); ").");
} }
bool ngraph::op::v0::LRN::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("alpha", m_alpha);
visitor.on_attribute("beta", m_beta);
visitor.on_attribute("bias", m_bias);
visitor.on_attribute("size", m_size);
return true;
}
shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::LRN::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
......
...@@ -58,6 +58,7 @@ namespace ngraph ...@@ -58,6 +58,7 @@ namespace ngraph
double bias, double bias,
size_t size); size_t size);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override; copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
......
...@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg) ...@@ -28,6 +28,11 @@ op::v1::LogicalNot::LogicalNot(const Output<Node>& arg)
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops. // TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v1::LogicalNot::validate_and_infer_types() void op::v1::LogicalNot::validate_and_infer_types()
{ {
......
...@@ -37,6 +37,7 @@ namespace ngraph ...@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor. /// \param arg Node that produces the input tensor.
LogicalNot(const Output<Node>& arg); LogicalNot(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -140,16 +141,26 @@ public: ...@@ -140,16 +141,26 @@ public:
double get_double(const string& name) { return m_doubles.at(name); } double get_double(const string& name) { return m_doubles.at(name); }
int64_t get_signed(const string& name) { return m_signeds.at(name); } int64_t get_signed(const string& name) { return m_signeds.at(name); }
uint64_t get_unsigned(const string& name) { return m_unsigneds.at(name); } uint64_t get_unsigned(const string& name) { return m_unsigneds.at(name); }
vector<float>& get_float_vector(const string& name) { return m_float_vectors.at(name); }
vector<int64_t>& get_signed_vector(const string& name) { return m_signed_vectors.at(name); } vector<int64_t>& get_signed_vector(const string& name) { return m_signed_vectors.at(name); }
vector<string>& get_string_vector(const string& name) { return m_string_vectors.at(name); }
void set_string(const string& name, const string& value) { m_strings[name] = value; } void set_string(const string& name, const string& value) { m_strings[name] = value; }
void set_bool(const string& name, bool value) { m_bools[name] = value; } void set_bool(const string& name, bool value) { m_bools[name] = value; }
void set_double(const string& name, double value) { m_doubles[name] = value; } void set_double(const string& name, double value) { m_doubles[name] = value; }
void set_signed(const string& name, int64_t value) { m_signeds[name] = value; } void set_signed(const string& name, int64_t value) { m_signeds[name] = value; }
void set_unsigned(const string& name, uint64_t value) { m_unsigneds[name] = value; } void set_unsigned(const string& name, uint64_t value) { m_unsigneds[name] = value; }
void set_float_vector(const string& name, const vector<float>& value)
{
m_float_vectors[name] = value;
}
void set_signed_vector(const string& name, const vector<int64_t>& value) void set_signed_vector(const string& name, const vector<int64_t>& value)
{ {
m_signed_vectors[name] = value; m_signed_vectors[name] = value;
} }
void set_string_vector(const string& name, const vector<string>& value)
{
m_string_vectors[name] = value;
}
void on_attribute(const string& name, string& value) override { set_string(name, value); }; void on_attribute(const string& name, string& value) override { set_string(name, value); };
void on_attribute(const string& name, bool& value) override { set_bool(name, value); } void on_attribute(const string& name, bool& value) override { set_bool(name, value); }
...@@ -162,10 +173,6 @@ public: ...@@ -162,10 +173,6 @@ public:
{ {
set_string(name, adapter.get()); set_string(name, adapter.get());
}; };
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
set_signed_vector(name, adapter.get());
}
void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
{ {
set_signed(name, adapter.get()); set_signed(name, adapter.get());
...@@ -174,6 +181,18 @@ public: ...@@ -174,6 +181,18 @@ public:
{ {
set_double(name, adapter.get()); set_double(name, adapter.get());
} }
void on_adapter(const string& name, ValueAccessor<vector<float>>& adapter) override
{
set_float_vector(name, adapter.get());
}
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
set_signed_vector(name, adapter.get());
}
void on_adapter(const string& name, ValueAccessor<vector<string>>& adapter) override
{
set_string_vector(name, adapter.get());
}
protected: protected:
NodeTypeInfo m_node_type_info; NodeTypeInfo m_node_type_info;
...@@ -183,6 +202,8 @@ protected: ...@@ -183,6 +202,8 @@ protected:
map<string, int64_t> m_signeds; map<string, int64_t> m_signeds;
map<string, uint64_t> m_unsigneds; map<string, uint64_t> m_unsigneds;
map<string, vector<int64_t>> m_signed_vectors; map<string, vector<int64_t>> m_signed_vectors;
map<string, vector<float>> m_float_vectors;
map<string, vector<std::string>> m_string_vectors;
}; };
class NodeBuilder : public AttributeVisitor class NodeBuilder : public AttributeVisitor
...@@ -197,7 +218,6 @@ public: ...@@ -197,7 +218,6 @@ public:
{ {
shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_values.get_node_type_info())); shared_ptr<Node> node(FactoryRegistry<Node>::get().create(m_values.get_node_type_info()));
node->visit_attributes(*this); node->visit_attributes(*this);
node->validate_and_infer_types();
return node; return node;
} }
...@@ -215,10 +235,6 @@ public: ...@@ -215,10 +235,6 @@ public:
{ {
adapter.set(m_values.get_string(name)); adapter.set(m_values.get_string(name));
}; };
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
adapter.set(m_values.get_signed_vector(name));
}
void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override void on_adapter(const string& name, ValueAccessor<int64_t>& adapter) override
{ {
adapter.set(m_values.get_signed(name)); adapter.set(m_values.get_signed(name));
...@@ -227,6 +243,18 @@ public: ...@@ -227,6 +243,18 @@ public:
{ {
adapter.set(m_values.get_double(name)); adapter.set(m_values.get_double(name));
} }
void on_adapter(const string& name, ValueAccessor<vector<int64_t>>& adapter) override
{
adapter.set(m_values.get_signed_vector(name));
}
void on_adapter(const string& name, ValueAccessor<vector<string>>& adapter) override
{
adapter.set(m_values.get_string_vector(name));
}
void on_adapter(const string& name, ValueAccessor<vector<float>>& adapter) override
{
adapter.set(m_values.get_float_vector(name));
}
protected: protected:
NodeSaver m_values; NodeSaver m_values;
...@@ -256,3 +284,217 @@ TEST(attributes, user_op) ...@@ -256,3 +284,217 @@ TEST(attributes, user_op)
EXPECT_EQ(g_oracle->get_hyper_parameters(), oracle->get_hyper_parameters()); EXPECT_EQ(g_oracle->get_hyper_parameters(), oracle->get_hyper_parameters());
EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters()); EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
} }
TEST(attributes, elu_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Elu>();
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 4});
double alpha = 0.1;
const auto elu = make_shared<opset1::Elu>(data, alpha);
NodeBuilder builder(elu);
auto g_elu = as_type_ptr<opset1::Elu>(builder.create());
EXPECT_EQ(g_elu->get_alpha(), elu->get_alpha());
}
TEST(attributes, fake_quantize_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::FakeQuantize>();
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
const auto input_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto input_high = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_low = make_shared<op::Parameter>(element::f32, Shape{});
const auto output_high = make_shared<op::Parameter>(element::f32, Shape{});
auto levels = 5;
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
const auto fake_quantize = make_shared<op::FakeQuantize>(
data, input_low, input_high, output_low, output_high, levels, auto_broadcast);
NodeBuilder builder(fake_quantize);
auto g_fake_quantize = as_type_ptr<opset1::FakeQuantize>(builder.create());
EXPECT_EQ(g_fake_quantize->get_levels(), fake_quantize->get_levels());
EXPECT_EQ(g_fake_quantize->get_auto_broadcast(), fake_quantize->get_auto_broadcast());
}
TEST(attributes, grn_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::GRN>();
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5});
float bias = 1.25f;
auto grn = make_shared<opset1::GRN>(data, bias);
NodeBuilder builder(grn);
auto g_grn = as_type_ptr<opset1::GRN>(builder.create());
EXPECT_EQ(g_grn->get_bias(), grn->get_bias());
}
TEST(attributes, group_conv_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::GroupConvolution>();
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 12, 224, 224});
auto filters = make_shared<op::Parameter>(element::f32, Shape{4, 1, 3, 5, 5});
auto strides = Strides{1, 1};
auto pads_begin = CoordinateDiff{1, 2};
auto pads_end = CoordinateDiff{1, 2};
auto dilations = Strides{1, 1};
auto group_conv = make_shared<opset1::GroupConvolution>(
data, filters, strides, pads_begin, pads_end, dilations, op::PadType::VALID);
NodeBuilder builder(group_conv);
auto g_group_conv = as_type_ptr<opset1::GroupConvolution>(builder.create());
EXPECT_EQ(g_group_conv->get_strides(), group_conv->get_strides());
EXPECT_EQ(g_group_conv->get_pads_begin(), group_conv->get_pads_begin());
EXPECT_EQ(g_group_conv->get_pads_end(), group_conv->get_pads_end());
EXPECT_EQ(g_group_conv->get_dilations(), group_conv->get_dilations());
EXPECT_EQ(g_group_conv->get_auto_pad(), group_conv->get_auto_pad());
}
TEST(attributes, group_conv_backprop_data_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::GroupConvolutionBackpropData>();
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 20, 224, 224});
const auto filter = make_shared<op::Parameter>(element::f32, Shape{4, 5, 2, 3, 3});
const auto output_shape = make_shared<op::Parameter>(element::f32, Shape{1, 8, 447, 447});
const auto strides = Strides{2, 1};
const auto pads_begin = CoordinateDiff{3, 4};
const auto pads_end = CoordinateDiff{4, 6};
const auto dilations = Strides{3, 1};
const auto auto_pad = op::PadType::EXPLICIT;
const auto output_padding = CoordinateDiff{3, 4};
const auto gcbd = make_shared<opset1::GroupConvolutionBackpropData>(data,
filter,
output_shape,
strides,
pads_begin,
pads_end,
dilations,
auto_pad,
output_padding);
NodeBuilder builder(gcbd);
const auto g_gcbd = as_type_ptr<opset1::GroupConvolutionBackpropData>(builder.create());
EXPECT_EQ(g_gcbd->get_strides(), gcbd->get_strides());
EXPECT_EQ(g_gcbd->get_pads_begin(), gcbd->get_pads_begin());
EXPECT_EQ(g_gcbd->get_pads_end(), gcbd->get_pads_end());
EXPECT_EQ(g_gcbd->get_dilations(), gcbd->get_dilations());
EXPECT_EQ(g_gcbd->get_auto_pad(), gcbd->get_auto_pad());
EXPECT_EQ(g_gcbd->get_output_padding(), gcbd->get_output_padding());
}
TEST(attributes, lrn_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LRN>();
const auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
const auto axes = make_shared<op::Parameter>(element::i32, Shape{2});
const double alpha = 1.1;
const double beta = 2.2;
const double bias = 3.3;
const size_t size = 4;
const auto lrn = make_shared<opset1::LRN>(arg, axes, alpha, beta, bias, size);
NodeBuilder builder(lrn);
auto g_lrn = as_type_ptr<opset1::LRN>(builder.create());
EXPECT_EQ(g_lrn->get_alpha(), lrn->get_alpha());
EXPECT_EQ(g_lrn->get_beta(), lrn->get_beta());
EXPECT_EQ(g_lrn->get_bias(), lrn->get_bias());
EXPECT_EQ(g_lrn->get_nsize(), lrn->get_nsize());
}
TEST(attributes, lstm_cell_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMCell>();
auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto W = make_shared<op::Parameter>(element::f32, Shape{12, 3});
auto R = make_shared<op::Parameter>(element::f32, Shape{12, 3});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});
const auto hidden_size = 3;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
auto activations_alpha = std::vector<float>{1.0, 1.5};
auto activations_beta = std::vector<float>{2.0, 1.0};
const float clip = 0.5f;
bool input_forget = true;
const auto lstm_cell = make_shared<opset1::LSTMCell>(X,
initial_hidden_state,
initial_cell_state,
W,
R,
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
NodeBuilder builder(lstm_cell);
auto g_lstm_cell = as_type_ptr<opset1::LSTMCell>(builder.create());
EXPECT_EQ(g_lstm_cell->get_hidden_size(), lstm_cell->get_hidden_size());
EXPECT_EQ(g_lstm_cell->get_activations(), lstm_cell->get_activations());
EXPECT_EQ(g_lstm_cell->get_activations_alpha(), lstm_cell->get_activations_alpha());
EXPECT_EQ(g_lstm_cell->get_activations_beta(), lstm_cell->get_activations_beta());
EXPECT_EQ(g_lstm_cell->get_clip(), lstm_cell->get_clip());
EXPECT_EQ(g_lstm_cell->get_input_forget(), lstm_cell->get_input_forget());
EXPECT_EQ(g_lstm_cell->get_weights_format(), lstm_cell->get_weights_format());
}
TEST(attributes, lstm_sequence_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();
const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto initial_hidden_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{2});
const auto W = make_shared<op::Parameter>(element::f32, Shape{1, 12, 4});
const auto R = make_shared<op::Parameter>(element::f32, Shape{1, 12, 3});
const auto B = make_shared<op::Parameter>(element::f32, Shape{1, 12});
const auto hidden_size = 3;
const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
const float clip_threshold = 0.5f;
const bool input_forget = true;
const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget);
NodeBuilder builder(lstm_sequence);
auto g_lstm_sequence = as_type_ptr<opset1::LSTMSequence>(builder.create());
EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha());
EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta());
EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold());
EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment