Unverified Commit 01698d7a authored by Tomasz Socha, committed by GitHub

Add attribute visitor for ops M-P (#4344)

parent 5d8c39e9
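This change gives each of the touched ops (MatMul, MaxPool, Mod, Negative, NonMaxSuppression, NormalizeL2, OneHot, Pad, PRelu, PSROIPooling) a visit_attributes override so their attributes can be captured and replayed generically, as the new round-trip tests in test/attributes.cpp do. As a rough sketch of that round trip (header paths and the location of the NodeBuilder test helper are assumptions for illustration, not part of this commit):

// Rough sketch of the attribute round trip exercised by the new tests below.
// Assumes the nGraph headers and the NodeBuilder helper from the attributes
// test harness are available; paths are illustrative.
#include "ngraph/ngraph.hpp"
#include "ngraph/factory.hpp"
#include "ngraph/opsets/opset1.hpp"

using namespace ngraph;

bool roundtrip_matmul_attributes()
{
    // Register the op so NodeBuilder::create() can construct a fresh instance.
    FactoryRegistry<Node>::get().register_factory<opset1::MatMul>();

    auto A = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{4, 3});
    auto matmul = std::make_shared<opset1::MatMul>(A, B, false, true);

    // NodeBuilder calls matmul->visit_attributes(...) to record "transpose_a"
    // and "transpose_b", then replays them onto a newly created MatMul.
    NodeBuilder builder(matmul);
    auto copy = as_type_ptr<opset1::MatMul>(builder.create());

    return copy->get_transpose_a() == matmul->get_transpose_a() &&
           copy->get_transpose_b() == matmul->get_transpose_b();
}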
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "psroi_pooling.hpp"
#include "ngraph/attribute_visitor.hpp"
using namespace std;
using namespace ngraph;
@@ -40,6 +41,17 @@ op::PSROIPooling::PSROIPooling(const Output<Node>& input,
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::PSROIPooling::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("output_dim", m_output_dim);
visitor.on_attribute("group_size", m_group_size);
visitor.on_attribute("spatial_scale", m_spatial_scale);
visitor.on_attribute("mode", m_mode);
visitor.on_attribute("spatial_bins_x", m_spatial_bins_x);
visitor.on_attribute("spatial_bins_y", m_spatial_bins_y);
return true;
}
void op::PSROIPooling::validate_and_infer_types()
{
auto input_et = get_input_element_type(0);
...
@@ -51,6 +51,7 @@ namespace ngraph
int spatial_bins_y,
const std::string& mode);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
...
@@ -17,6 +17,7 @@
#include <numeric>
#include "matmul.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/matmul_factory.hpp" #include "ngraph/builder/matmul_factory.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
...@@ -37,6 +38,13 @@ op::MatMul::MatMul(const Output<Node>& A, ...@@ -37,6 +38,13 @@ op::MatMul::MatMul(const Output<Node>& A,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::MatMul::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("transpose_a", m_transpose_a);
visitor.on_attribute("transpose_b", m_transpose_b);
return true;
}
void op::MatMul::pre_validate_and_infer_types()
{
element::Type result_et;
...
@@ -44,6 +44,7 @@ namespace ngraph
const bool& transpose_a = 0,
const bool& transpose_b = 0);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual NodeVector decompose_op() const override;
...
@@ -14,6 +14,7 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/fused/mod.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/make_constant.hpp" #include "ngraph/builder/make_constant.hpp"
#include "ngraph/op/abs.hpp" #include "ngraph/op/abs.hpp"
#include "ngraph/op/convert.hpp" #include "ngraph/op/convert.hpp"
...@@ -35,6 +36,12 @@ op::v1::Mod::Mod(const Output<Node>& A, ...@@ -35,6 +36,12 @@ op::v1::Mod::Mod(const Output<Node>& A,
{ {
} }
bool ngraph::op::v1::Mod::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("auto_broadcast", m_auto_broadcast);
return true;
}
NodeVector op::v1::Mod::decompose_op() const
{
const auto dividend = make_shared<op::Abs>(input_value(0));
...
@@ -43,6 +43,7 @@ namespace ngraph
const Output<Node>& B,
const AutoBroadcastSpec& auto_broadcast = AutoBroadcastType::NUMPY);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <iterator>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/norm.hpp" #include "ngraph/builder/norm.hpp"
#include "ngraph/builder/reshape.hpp" #include "ngraph/builder/reshape.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -39,6 +40,13 @@ op::NormalizeL2::NormalizeL2(const Output<Node>& data, ...@@ -39,6 +40,13 @@ op::NormalizeL2::NormalizeL2(const Output<Node>& data,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
bool ngraph::op::v0::NormalizeL2::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("eps", m_eps);
visitor.on_attribute("eps_mode", m_eps_mode);
return true;
}
void op::NormalizeL2::pre_validate_and_infer_types()
{
auto axes_node = input_value(1).get_node_shared_ptr();
...
@@ -52,6 +52,7 @@ namespace ngraph
float eps,
EpsMode eps_mode);
bool visit_attributes(AttributeVisitor& visitor) override;
float get_eps() const { return m_eps; }
EpsMode get_eps_mode() const { return m_eps_mode; }
virtual NodeVector decompose_op() const override;
...
@@ -35,6 +35,11 @@ op::PRelu::PRelu(const Output<Node>& data, const Output<Node>& slope)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::PRelu::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
NodeVector op::PRelu::decompose_op() const
{
auto data = input_value(0);
...
@@ -42,6 +42,7 @@ namespace ngraph
/// \param slope Multipliers for negative values
PRelu(const Output<Node>& data, const Output<Node>& slope);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual NodeVector decompose_op() const override;
virtual std::shared_ptr<Node>
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/max_pool.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp" #include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
...@@ -303,6 +304,17 @@ op::v1::MaxPool::MaxPool(const Output<Node>& arg, ...@@ -303,6 +304,17 @@ op::v1::MaxPool::MaxPool(const Output<Node>& arg,
{ {
} }
bool ngraph::op::v1::MaxPool::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("strides", m_strides);
visitor.on_attribute("pads_begin", m_pads_begin);
visitor.on_attribute("pads_end", m_pads_end);
visitor.on_attribute("kernel", m_kernel);
visitor.on_attribute("rounding_type", m_rounding_type);
visitor.on_attribute("auto_pad", m_auto_pad);
return true;
}
void op::v1::MaxPool::validate_and_infer_types()
{
if (0 == m_strides.size())
...
@@ -247,6 +247,7 @@ namespace ngraph
const Shape& kernel,
op::RoundingType rounding_mode);
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
...
@@ -27,6 +27,11 @@ op::Negative::Negative(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
bool ngraph::op::v0::Negative::visit_attributes(AttributeVisitor& visitor)
{
return true;
}
shared_ptr<Node> op::Negative::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
...
@@ -37,6 +37,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Negative(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
using namespace std; using namespace std;
...@@ -65,6 +66,13 @@ shared_ptr<Node> op::v1::NonMaxSuppression::copy_with_new_args(const NodeVector& ...@@ -65,6 +66,13 @@ shared_ptr<Node> op::v1::NonMaxSuppression::copy_with_new_args(const NodeVector&
m_sort_result_descending); m_sort_result_descending);
} }
bool ngraph::op::v1::NonMaxSuppression::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("box_encoding", m_box_encoding);
visitor.on_attribute("sort_result_descending", m_sort_result_descending);
return true;
}
void op::v1::NonMaxSuppression::validate_and_infer_types()
{
const auto boxes_ps = get_input_partial_shape(0);
@@ -157,3 +165,26 @@ int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
return max_output_boxes;
}
namespace ngraph
{
template <>
EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>&
EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>::get()
{
static auto enum_names = EnumNames<op::v1::NonMaxSuppression::BoxEncodingType>(
"op::v1::NonMaxSuppression::BoxEncodingType",
{{"corner", op::v1::NonMaxSuppression::BoxEncodingType::CORNER},
{"center", op::v1::NonMaxSuppression::BoxEncodingType::CENTER}});
return enum_names;
}
constexpr DiscreteTypeInfo
AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>::type_info;
std::ostream& operator<<(std::ostream& s,
const op::v1::NonMaxSuppression::BoxEncodingType& type)
{
return s << as_string(type);
}
} // namespace ngraph
@@ -68,6 +68,7 @@ namespace ngraph
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
@@ -93,4 +94,22 @@ namespace ngraph
};
}
}
std::ostream& operator<<(std::ostream& s,
const op::v1::NonMaxSuppression::BoxEncodingType& type);
template <>
class NGRAPH_API AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>
: public EnumAttributeAdapterBase<op::v1::NonMaxSuppression::BoxEncodingType>
{
public:
AttributeAdapter(op::v1::NonMaxSuppression::BoxEncodingType& value)
: EnumAttributeAdapterBase<op::v1::NonMaxSuppression::BoxEncodingType>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v1::NonMaxSuppression::BoxEncodingType>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} }
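The BoxEncodingType additions above follow the usual recipe for making an enum visitable: an EnumNames table mapping enum values to their serialized strings, an AttributeAdapter specialization built on EnumAttributeAdapterBase, and a stream operator. A condensed sketch of the same recipe for a hypothetical enum (the enum, its names, and the header paths are illustrative assumptions, not part of this commit):

// Condensed sketch of the enum-attribute boilerplate used above, applied to a
// hypothetical enum "MyMode"; header paths are assumed.
#include <ostream>
#include "ngraph/attribute_adapter.hpp"
#include "ngraph/enum_names.hpp"

namespace ngraph
{
    enum class MyMode { FIRST, SECOND };

    // 1. Map enum values to the strings used during (de)serialization.
    template <>
    EnumNames<MyMode>& EnumNames<MyMode>::get()
    {
        static auto enum_names =
            EnumNames<MyMode>("MyMode", {{"first", MyMode::FIRST}, {"second", MyMode::SECOND}});
        return enum_names;
    }

    // 2. Let AttributeVisitor::on_attribute accept MyMode members.
    template <>
    class AttributeAdapter<MyMode> : public EnumAttributeAdapterBase<MyMode>
    {
    public:
        AttributeAdapter(MyMode& value)
            : EnumAttributeAdapterBase<MyMode>(value)
        {
        }
        static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<MyMode>", 1};
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
    constexpr DiscreteTypeInfo AttributeAdapter<MyMode>::type_info;

    // 3. Printing support, reusing the EnumNames table.
    std::ostream& operator<<(std::ostream& s, const MyMode& mode) { return s << as_string(mode); }
}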
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/one_hot.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
using namespace std; using namespace std;
...@@ -194,6 +195,12 @@ void op::v1::OneHot::validate_and_infer_types() ...@@ -194,6 +195,12 @@ void op::v1::OneHot::validate_and_infer_types()
set_output_type(0, on_value_et, result_shape); set_output_type(0, on_value_et, result_shape);
} }
bool ngraph::op::v1::OneHot::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("axis", m_axis);
return true;
}
shared_ptr<Node> op::v1::OneHot::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
...
@@ -98,6 +98,7 @@ namespace ngraph
const Output<Node>& off_value,
int64_t axis);
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
void validate_and_infer_types() override;
...
@@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/pad.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/broadcast.hpp" #include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
...@@ -222,6 +223,12 @@ CoordinateDiff op::v1::Pad::get_pads_end() const ...@@ -222,6 +223,12 @@ CoordinateDiff op::v1::Pad::get_pads_end() const
return pads_end_coord; return pads_end_coord;
} }
bool ngraph::op::v1::Pad::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("pad_mode", m_pad_mode);
return true;
}
void op::v1::Pad::validate_and_infer_types()
{
element::Type result_et;
...
@@ -135,6 +135,7 @@ namespace ngraph
/// \brief Constructs a generic padding operation.
Pad() = default;
bool visit_attributes(AttributeVisitor& visitor) override;
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
...
@@ -285,6 +285,171 @@ TEST(attributes, user_op)
EXPECT_EQ(g_oracle->get_ultra_parameters(), oracle->get_ultra_parameters());
}
TEST(attributes, matmul_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::MatMul>();
auto A = make_shared<op::Parameter>(element::f32, Shape{0, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 0});
bool transpose_a = true;
bool transpose_b = true;
auto matmul = make_shared<opset1::MatMul>(A, B, transpose_a, transpose_b);
NodeBuilder builder(matmul);
auto g_matmul = as_type_ptr<opset1::MatMul>(builder.create());
EXPECT_EQ(g_matmul->get_transpose_a(), matmul->get_transpose_a());
EXPECT_EQ(g_matmul->get_transpose_b(), matmul->get_transpose_b());
}
TEST(attributes, max_pool_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::MaxPool>();
auto data = make_shared<op::Parameter>(element::f32, Shape{64, 3, 5});
auto strides = Strides{2};
auto pads_begin = Shape{1};
auto pads_end = Shape{1};
auto kernel = Shape{1};
auto rounding_mode = op::RoundingType::FLOOR;
auto auto_pad = op::PadType::EXPLICIT;
auto max_pool = make_shared<opset1::MaxPool>(
data, strides, pads_begin, pads_end, kernel, rounding_mode, auto_pad);
NodeBuilder builder(max_pool);
auto g_max_pool = as_type_ptr<opset1::MaxPool>(builder.create());
EXPECT_EQ(g_max_pool->get_strides(), max_pool->get_strides());
EXPECT_EQ(g_max_pool->get_pads_begin(), max_pool->get_pads_begin());
EXPECT_EQ(g_max_pool->get_pads_end(), max_pool->get_pads_end());
EXPECT_EQ(g_max_pool->get_kernel(), max_pool->get_kernel());
EXPECT_EQ(g_max_pool->get_rounding_type(), max_pool->get_rounding_type());
EXPECT_EQ(g_max_pool->get_auto_pad(), max_pool->get_auto_pad());
}
TEST(attributes, mod_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Mod>();
auto A = make_shared<op::Parameter>(element::f32, Shape{0, 2});
auto B = make_shared<op::Parameter>(element::f32, Shape{2, 0});
auto auto_broadcast = op::AutoBroadcastType::NUMPY;
auto mod = make_shared<opset1::Mod>(A, B, auto_broadcast);
NodeBuilder builder(mod);
auto g_mod = as_type_ptr<opset1::Mod>(builder.create());
EXPECT_EQ(g_mod->get_auto_broadcast(), mod->get_auto_broadcast());
}
TEST(attributes, non_max_suppression_op_custom_attributes)
{
FactoryRegistry<Node>::get().register_factory<opset1::NonMaxSuppression>();
auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});
auto box_encoding = opset1::NonMaxSuppression::BoxEncodingType::CENTER;
bool sort_result_descending = false;
auto nms =
make_shared<opset1::NonMaxSuppression>(boxes, scores, box_encoding, sort_result_descending);
NodeBuilder builder(nms);
auto g_nms = as_type_ptr<opset1::NonMaxSuppression>(builder.create());
EXPECT_EQ(g_nms->get_box_encoding(), nms->get_box_encoding());
EXPECT_EQ(g_nms->get_sort_result_descending(), nms->get_sort_result_descending());
}
TEST(attributes, non_max_suppression_op_default_attributes)
{
FactoryRegistry<Node>::get().register_factory<opset1::NonMaxSuppression>();
auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});
auto nms = make_shared<opset1::NonMaxSuppression>(boxes, scores);
NodeBuilder builder(nms);
auto g_nms = as_type_ptr<opset1::NonMaxSuppression>(builder.create());
EXPECT_EQ(g_nms->get_box_encoding(), nms->get_box_encoding());
EXPECT_EQ(g_nms->get_sort_result_descending(), nms->get_sort_result_descending());
}
TEST(attributes, normalize_l2_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::NormalizeL2>();
auto data = make_shared<op::Parameter>(element::i32, Shape{1});
const auto axes = make_shared<op::Constant>(element::i32, Shape{}, vector<int32_t>{0});
float eps{1e-6f};
auto eps_mode = op::EpsMode::ADD;
auto normalize_l2 = make_shared<opset1::NormalizeL2>(data, axes, eps, eps_mode);
NodeBuilder builder(normalize_l2);
auto g_normalize_l2 = as_type_ptr<opset1::NormalizeL2>(builder.create());
EXPECT_EQ(g_normalize_l2->get_eps(), normalize_l2->get_eps());
EXPECT_EQ(g_normalize_l2->get_eps_mode(), normalize_l2->get_eps_mode());
}
TEST(attributes, one_hot_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::OneHot>();
auto indices = make_shared<op::Parameter>(element::i64, Shape{1, 3, 2, 3});
auto depth = op::Constant::create(element::i64, Shape{}, {4});
auto on_value = op::Constant::create(element::f32, Shape{}, {1.0f});
auto off_value = op::Constant::create(element::f32, Shape{}, {0.0f});
int64_t axis = 3;
auto one_hot = make_shared<opset1::OneHot>(indices, depth, on_value, off_value, axis);
NodeBuilder builder(one_hot);
auto g_one_hot = as_type_ptr<opset1::OneHot>(builder.create());
EXPECT_EQ(g_one_hot->get_axis(), one_hot->get_axis());
}
TEST(attributes, pad_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::Pad>();
auto arg = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
auto pads_begin = make_shared<op::Parameter>(element::i64, Shape{1});
auto pads_end = make_shared<op::Parameter>(element::i64, Shape{1});
auto pad_mode = op::PadMode::EDGE;
auto pad = make_shared<opset1::Pad>(arg, pads_begin, pads_end, pad_mode);
NodeBuilder builder(pad);
auto g_pad = as_type_ptr<opset1::Pad>(builder.create());
EXPECT_EQ(g_pad->get_pad_mode(), pad->get_pad_mode());
}
TEST(attributes, psroi_pooling_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::PSROIPooling>();
auto input = make_shared<op::Parameter>(element::f32, Shape{1, 1024, 63, 38});
auto coords = make_shared<op::Parameter>(element::f32, Shape{300, 5});
const int64_t output_dim = 882;
const int64_t group_size = 3;
const float spatial_scale = 0.0625;
int spatial_bins_x = 1;
int spatial_bins_y = 1;
string mode = "Avg";
auto psroi_pool = make_shared<opset1::PSROIPooling>(
input, coords, output_dim, group_size, spatial_scale, spatial_bins_x, spatial_bins_y, mode);
NodeBuilder builder(psroi_pool);
auto g_psroi_pool = as_type_ptr<opset1::PSROIPooling>(builder.create());
EXPECT_EQ(g_psroi_pool->get_output_dim(), psroi_pool->get_output_dim());
EXPECT_EQ(g_psroi_pool->get_group_size(), psroi_pool->get_group_size());
EXPECT_EQ(g_psroi_pool->get_spatial_scale(), psroi_pool->get_spatial_scale());
EXPECT_EQ(g_psroi_pool->get_spatial_bins_x(), psroi_pool->get_spatial_bins_x());
EXPECT_EQ(g_psroi_pool->get_spatial_bins_y(), psroi_pool->get_spatial_bins_y());
EXPECT_EQ(g_psroi_pool->get_mode(), psroi_pool->get_mode());
}
TEST(attributes, reduce_logical_and_op)
{
// ReduceLogicalAnd derives visit_attributes from op::util::LogicalReductionKeepDims
...