Unverified Commit e1ebbf12 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by GitHub

RoiAlign operator - ONNX importer + core (#4432)

* ROIAlign op skeleton

* Expose ROIAlign via opset2

* ROIAlign attribute visitor

* Batch indices input for ROIAlign

* Shape inference for ROIAlign

* ROIAlign mode attribute validation

* onnx_importer ROIAlign op

* size_t of get_attribute_value in onnx_importer::Node

* Explicit cast size_t -> int64_t when creating Dimensions

* ...aaand some new lines to make clang format happy

* Rank check fix for the batch indices input.
Co-Authored-By: 's avatarAdam Osewski <adam.osewski@intel.com>

* Typo
Co-Authored-By: 's avatarAdam Osewski <adam.osewski@intel.com>

* Empty opset3 definition

* Move ROIAlign to opse3

* Review comments

* Correct rank check for batch indices input

* Basic shape inference test for ROIAlign

* Move ROIAlign to ops from experimental/layers

* ROIAlign -> RoiAlign
Co-Authored-By: 's avatarKatarzyna Mitrus <katarzyna.mitrus@intel.com>

* Move ROIAlign from v0 to v3

* Support more data types in ROIAlign

* Move the ROIAlign from opset2_tbl to opset3_tbl

* Don't include opset2 in opset3.hpp

* PoolingMode enum for ROIAlign attribute

* Use EnumNames to handle string/enum conversion

* More checks to prevent segfaults

* Remove the std prefix for string
Co-Authored-By: 's avatarAdam Osewski <adam.osewski@intel.com>

* Fix more potential segfaults

* Build break fix

* ROIAlign serializer

* PR feedback

* Disable the ROIAlign op in GPU backend

* ROIAlign type prop UT

* More shape inference UT

* PR comments for RoiAlign

* code formatting
Co-authored-by: 's avatarAdam Osewski <adam.osewski@intel.com>
Co-authored-by: 's avatarKatarzyna Mitrus <katarzyna.mitrus@intel.com>
Co-authored-by: 's avatarRobert Kimball <robert.kimball@intel.com>
parent a315f2dd
...@@ -301,6 +301,8 @@ set (SRC ...@@ -301,6 +301,8 @@ set (SRC
op/reduce_mean.hpp op/reduce_mean.hpp
op/reduce_sum.cpp op/reduce_sum.cpp
op/reduce_sum.hpp op/reduce_sum.hpp
op/roi_align.hpp
op/roi_align.cpp
op/round.cpp op/round.cpp
op/round.hpp op/round.hpp
op/quantize.cpp op/quantize.cpp
......
...@@ -177,6 +177,8 @@ add_library(onnx_import STATIC ...@@ -177,6 +177,8 @@ add_library(onnx_import STATIC
op/reshape.hpp op/reshape.hpp
op/reverse_sequence.cpp op/reverse_sequence.cpp
op/reverse_sequence.hpp op/reverse_sequence.hpp
op/roi_align.cpp
op/roi_align.hpp
op/round.cpp op/round.cpp
op/round.hpp op/round.hpp
op/scatter_nd.cpp op/scatter_nd.cpp
......
...@@ -220,6 +220,12 @@ namespace ngraph ...@@ -220,6 +220,12 @@ namespace ngraph
return m_pimpl->template get_attribute_value<std::int64_t>(name, default_value); return m_pimpl->template get_attribute_value<std::int64_t>(name, default_value);
} }
template <>
int Node::get_attribute_value(const std::string& name, int default_value) const
{
return m_pimpl->template get_attribute_value<int>(name, default_value);
}
template <> template <>
std::string Node::get_attribute_value(const std::string& name, std::string Node::get_attribute_value(const std::string& name,
std::string default_value) const std::string default_value) const
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/opsets/opset3.hpp"
#include "roi_align.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector roi_align(const Node& node)
{
const auto inputs = node.get_ng_inputs();
NGRAPH_CHECK(inputs.size() == 3,
"The RoiAlign operator expects 3 inputs. Got: ",
inputs.size());
const auto& data = inputs[0];
const auto& rois = inputs[1];
const auto& num_rois = inputs[2];
const auto pooled_h = node.get_attribute_value<int>("output_height", 1);
const auto pooled_w = node.get_attribute_value<int>("output_width", 1);
const auto sampling_ratio = node.get_attribute_value<int>("sampling_ratio", 1);
const auto spatial_scale =
node.get_attribute_value<float>("spatial_scale", 1.0f);
const auto mode = node.get_attribute_value<std::string>("mode", "avg");
return {std::make_shared<ngraph::opset3::ROIAlign>(data,
rois,
num_rois,
pooled_h,
pooled_w,
sampling_ratio,
spatial_scale,
mode)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector roi_align(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
...@@ -103,6 +103,7 @@ ...@@ -103,6 +103,7 @@
#include "op/relu.hpp" #include "op/relu.hpp"
#include "op/reshape.hpp" #include "op/reshape.hpp"
#include "op/reverse_sequence.hpp" #include "op/reverse_sequence.hpp"
#include "op/roi_align.hpp"
#include "op/round.hpp" #include "op/round.hpp"
#include "op/scatter_nd.hpp" #include "op/scatter_nd.hpp"
#include "op/selu.hpp" #include "op/selu.hpp"
...@@ -340,6 +341,7 @@ namespace ngraph ...@@ -340,6 +341,7 @@ namespace ngraph
REGISTER_OPERATOR("Relu", 1, relu); REGISTER_OPERATOR("Relu", 1, relu);
REGISTER_OPERATOR("Reshape", 1, reshape); REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence); REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
REGISTER_OPERATOR("RoiAlign", 1, roi_align);
REGISTER_OPERATOR("Round", 1, round); REGISTER_OPERATOR("Round", 1, round);
REGISTER_OPERATOR("ScatterND", 1, scatter_nd); REGISTER_OPERATOR("ScatterND", 1, scatter_nd);
REGISTER_OPERATOR("Selu", 1, selu); REGISTER_OPERATOR("Selu", 1, selu);
......
...@@ -205,6 +205,7 @@ NGRAPH_OP(Reverse, ngraph::op::v0, 0) ...@@ -205,6 +205,7 @@ NGRAPH_OP(Reverse, ngraph::op::v0, 0)
NGRAPH_OP(Reverse, ngraph::op::v1, 1) NGRAPH_OP(Reverse, ngraph::op::v1, 1)
NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0) NGRAPH_OP(ReverseSequence, ngraph::op::v0, 0)
NGRAPH_OP(Round, ngraph::op::v0, 0) NGRAPH_OP(Round, ngraph::op::v0, 0)
NGRAPH_OP(ROIAlign, ngraph::op::v3, 3)
NGRAPH_OP(ScalarConstantLike, ngraph::op::v0, 0) NGRAPH_OP(ScalarConstantLike, ngraph::op::v0, 0)
NGRAPH_OP(ScaleShift, ngraph::op::v0, 0) NGRAPH_OP(ScaleShift, ngraph::op::v0, 0)
NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0) NGRAPH_OP(ScatterAdd, ngraph::op::v0, 0)
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "roi_align.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v3::ROIAlign::type_info;
op::v3::ROIAlign::ROIAlign(const Output<Node>& input,
const Output<Node>& rois,
const Output<Node>& batch_indices,
const int pooled_h,
const int pooled_w,
const int sampling_ratio,
const float spatial_scale,
const string& mode)
: Op{{input, rois, batch_indices}}
, m_pooled_h{pooled_h}
, m_pooled_w{pooled_w}
, m_sampling_ratio{sampling_ratio}
, m_spatial_scale{spatial_scale}
, m_mode{EnumNames<ROIAlign::PoolingMode>::as_enum(mode)}
{
constructor_validate_and_infer_types();
}
op::v3::ROIAlign::ROIAlign(const Output<Node>& input,
const Output<Node>& rois,
const Output<Node>& batch_indices,
const int pooled_h,
const int pooled_w,
const int sampling_ratio,
const float spatial_scale,
const PoolingMode mode)
: Op{{input, rois, batch_indices}}
, m_pooled_h{pooled_h}
, m_pooled_w{pooled_w}
, m_sampling_ratio{sampling_ratio}
, m_spatial_scale{spatial_scale}
, m_mode{mode}
{
constructor_validate_and_infer_types();
}
void op::v3::ROIAlign::validate_and_infer_types()
{
NODE_VALIDATION_CHECK(
this,
get_input_element_type(0).is_real() && get_input_element_type(1).is_real(),
"The data type for input and ROIs is expected to be a floating point type. Got: ",
get_input_element_type(0),
" and: ",
get_input_element_type(1));
NODE_VALIDATION_CHECK(this,
get_input_element_type(2).is_integral_number(),
"The data type for batch indices is expected to be an integer. Got: ",
get_input_element_type(2));
const auto& input_ps = get_input_partial_shape(0);
NODE_VALIDATION_CHECK(this,
input_ps.rank().compatible(4),
"Expected a 4D tensor for the input data. Got: ",
input_ps);
const auto& rois_ps = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
rois_ps.rank().compatible(2),
"Expected a 2D tensor for the ROIs input. Got: ",
rois_ps);
const auto& batch_indices_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
batch_indices_ps.rank().compatible(1),
"Expected a 1D tensor for the batch indices input. Got: ",
batch_indices_ps);
if (rois_ps.rank().is_static())
{
const auto rois_second_dim = rois_ps[1];
NODE_VALIDATION_CHECK(this,
rois_second_dim.compatible(4),
"The second dimension of ROIs input should contain box coordinates. ",
"This dimension is expected to be equal to 4. Got: ",
rois_second_dim);
if (batch_indices_ps.rank().is_static())
{
NODE_VALIDATION_CHECK(
this,
rois_ps[0].compatible(batch_indices_ps[0]),
"The first dimension of ROIs input must be equal to the first dimension ",
"of the batch indices input. Got: ",
rois_ps[0],
" and: ",
batch_indices_ps[0]);
}
}
// the output shape should have the following format [NUM_ROIS, C, pooled_h, pooled_w]
auto output_shape = PartialShape{{Dimension::dynamic(),
input_ps[1],
Dimension{static_cast<int64_t>(m_pooled_h)},
Dimension{static_cast<int64_t>(m_pooled_w)}}};
// if either of those 2 dimensions is static its value will be used
// for the first dimension of the output shape - 'NUM_ROIS'
if (rois_ps.rank().is_static() && rois_ps[0].is_static())
{
output_shape[0] = rois_ps[0];
}
else if (batch_indices_ps.rank().is_static() && batch_indices_ps[0].is_static())
{
output_shape[0] = batch_indices_ps[0];
}
set_output_size(1);
set_output_type(0, get_input_element_type(0), output_shape);
// if the channels dimension is not known
// the first input should be used during the function specialization
if (input_ps.rank().is_static() && input_ps[1].is_dynamic())
{
set_input_is_relevant_to_shape(0);
}
// if the 'NUM_ROIS' value is not known
// the last 2 inputs should be used during the function specialization
if (output_shape[0].is_dynamic())
{
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
}
}
bool op::v3::ROIAlign::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("pooled_h", m_pooled_h);
visitor.on_attribute("pooled_w", m_pooled_w);
visitor.on_attribute("sampling_ratio", m_sampling_ratio);
visitor.on_attribute("spatial_scale", m_spatial_scale);
visitor.on_attribute("mode", m_mode);
return true;
}
shared_ptr<Node> op::v3::ROIAlign::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<ROIAlign>(new_args.at(0),
new_args.at(1),
new_args.at(2),
m_pooled_h,
m_pooled_w,
m_sampling_ratio,
m_spatial_scale,
m_mode);
}
namespace ngraph
{
constexpr DiscreteTypeInfo AttributeAdapter<op::v3::ROIAlign::PoolingMode>::type_info;
template <>
EnumNames<op::v3::ROIAlign::PoolingMode>& EnumNames<op::v3::ROIAlign::PoolingMode>::get()
{
static auto enum_names =
EnumNames<op::v3::ROIAlign::PoolingMode>("op::v3::ROIAlign::PoolingMode",
{{"avg", op::v3::ROIAlign::PoolingMode::AVG},
{"max", op::v3::ROIAlign::PoolingMode::MAX}});
return enum_names;
}
std::ostream& operator<<(std::ostream& s, const op::v3::ROIAlign::PoolingMode& type)
{
return s << as_string(type);
}
}
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v3
{
class NGRAPH_API ROIAlign : public Op
{
public:
enum class PoolingMode
{
AVG,
MAX
};
static constexpr NodeTypeInfo type_info{"ROIAlign", 3};
const NodeTypeInfo& get_type_info() const override { return type_info; }
ROIAlign() = default;
/// \brief Constructs a ROIAlign node matching the ONNX ROIAlign specification
///
/// \param input Input feature map {N, C, H, W}
/// \param rois Regions of interest to pool over
/// \param batch_indices Indices of images in the batch matching
/// the number or ROIs
/// \param pooled_h Height of the ROI output features
/// \param pooled_w Width of the ROI output features
/// \param sampling_ratio Number of sampling points used to compute
/// an output element
/// \param spatial_scale Spatial scale factor used to translate ROI coordinates
/// \param mode Method of pooling - 'avg' or 'max'
ROIAlign(const Output<Node>& input,
const Output<Node>& rois,
const Output<Node>& batch_indices,
const int pooled_h,
const int pooled_w,
const int sampling_ratio,
const float spatial_scale,
const std::string& mode);
ROIAlign(const Output<Node>& input,
const Output<Node>& rois,
const Output<Node>& batch_indices,
const int pooled_h,
const int pooled_w,
const int sampling_ratio,
const float spatial_scale,
const PoolingMode mode);
virtual void validate_and_infer_types() override;
virtual bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
int get_pooled_h() const { return m_pooled_h; }
int get_pooled_w() const { return m_pooled_w; }
int get_sampling_ratio() const { return m_sampling_ratio; }
float get_spatial_scale() const { return m_spatial_scale; }
PoolingMode get_mode() const { return m_mode; }
private:
PoolingMode mode_from_string(const std::string& mode) const;
private:
int m_pooled_h;
int m_pooled_w;
int m_sampling_ratio;
float m_spatial_scale;
PoolingMode m_mode;
};
}
using v3::ROIAlign;
}
std::ostream& operator<<(std::ostream& s, const op::v3::ROIAlign::PoolingMode& mode);
template <>
class NGRAPH_API AttributeAdapter<op::v3::ROIAlign::PoolingMode>
: public EnumAttributeAdapterBase<op::v3::ROIAlign::PoolingMode>
{
public:
AttributeAdapter(op::v3::ROIAlign::PoolingMode& value)
: EnumAttributeAdapterBase<op::v3::ROIAlign::PoolingMode>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v3::ROIAlign::PoolingMode>", 3};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}
...@@ -160,6 +160,7 @@ ...@@ -160,6 +160,7 @@
#include "ngraph/op/result.hpp" #include "ngraph/op/result.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/reverse_sequence.hpp" #include "ngraph/op/reverse_sequence.hpp"
#include "ngraph/op/roi_align.hpp"
#include "ngraph/op/round.hpp" #include "ngraph/op/round.hpp"
#include "ngraph/op/scatter_add.hpp" #include "ngraph/op/scatter_add.hpp"
#include "ngraph/op/scatter_nd_add.hpp" #include "ngraph/op/scatter_nd_add.hpp"
......
...@@ -21,3 +21,4 @@ ...@@ -21,3 +21,4 @@
#include "opset2_tbl.hpp" #include "opset2_tbl.hpp"
NGRAPH_OP(NonZero, ngraph::op::v3) NGRAPH_OP(NonZero, ngraph::op::v3)
NGRAPH_OP(ROIAlign, ngraph::op::v3)
...@@ -1995,3 +1995,8 @@ std::string runtime::gpu::GPU_Emitter::emit_v3_NonZero(EMIT_ARGS) ...@@ -1995,3 +1995,8 @@ std::string runtime::gpu::GPU_Emitter::emit_v3_NonZero(EMIT_ARGS)
{ {
throw unsupported_op("Unsupported op '" + node->description() + "'"); throw unsupported_op("Unsupported op '" + node->description() + "'");
} }
std::string runtime::gpu::GPU_Emitter::emit_v3_ROIAlign(EMIT_ARGS)
{
throw unsupported_op("Unsupported op '" + node->description() + "'");
}
...@@ -2629,6 +2629,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js) ...@@ -2629,6 +2629,19 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
} }
break; break;
} }
case OP_TYPEID::ROIAlign_v3:
{
const auto pooled_h = node_js.at("pooled_h").get<size_t>();
const auto pooled_w = node_js.at("pooled_w").get<size_t>();
const auto sampling_ratio = node_js.at("sampling_ratio").get<size_t>();
const auto spatial_scale = node_js.at("spatial_scale").get<float>();
const auto mode = node_js.at("mode").get<op::ROIAlign::PoolingMode>();
node = make_shared<op::ROIAlign>(
args[0], args[1], args[2], pooled_h, pooled_w, sampling_ratio, spatial_scale, mode);
break;
}
case OP_TYPEID::ROIPooling: { break; case OP_TYPEID::ROIPooling: { break;
} }
case OP_TYPEID::RegionYolo: { break; case OP_TYPEID::RegionYolo: { break;
...@@ -4386,6 +4399,16 @@ json JSONSerializer::serialize_node(const Node& n) ...@@ -4386,6 +4399,16 @@ json JSONSerializer::serialize_node(const Node& n)
node["activations_beta"] = tmp->get_activations_beta(); node["activations_beta"] = tmp->get_activations_beta();
break; break;
} }
case OP_TYPEID::ROIAlign_v3:
{
auto tmp = static_cast<const op::ROIAlign*>(&n);
node["pooled_h"] = tmp->get_pooled_h();
node["pooled_w"] = tmp->get_pooled_w();
node["sampling_ratio"] = tmp->get_sampling_ratio();
node["spatial_scale"] = tmp->get_spatial_scale();
node["mode"] = tmp->get_mode();
break;
}
case OP_TYPEID::ScalarConstantLike: case OP_TYPEID::ScalarConstantLike:
{ {
auto tmp = static_cast<const op::ScalarConstantLike*>(&n); auto tmp = static_cast<const op::ScalarConstantLike*>(&n);
......
...@@ -172,6 +172,7 @@ set(SRC ...@@ -172,6 +172,7 @@ set(SRC
type_prop/reshape.cpp type_prop/reshape.cpp
type_prop/reverse.cpp type_prop/reverse.cpp
type_prop/reverse_sequence.cpp type_prop/reverse_sequence.cpp
type_prop/roi_align.cpp
type_prop/rnn_cell.cpp type_prop/rnn_cell.cpp
type_prop/scale_shift.cpp type_prop/scale_shift.cpp
type_prop/scatter_add.cpp type_prop/scatter_add.cpp
......
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop_layers, roi_align_basic_shape_inference)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 5, 5});
const auto rois = make_shared<op::Parameter>(element::f32, Shape{7, 4});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{7});
const auto op = make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 2, 2, 1, 1.0f, "avg");
ASSERT_EQ(op->get_shape(), (Shape{7, 3, 2, 2}));
}
TEST(type_prop_layers, roi_align_dynamic_channels_dim)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{10, Dimension(), 5, 5});
const auto rois = make_shared<op::Parameter>(element::f32, Shape{7, 4});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{7});
const auto op = make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 3, 4, 1, 1.0f, "avg");
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape{7, Dimension(), 3, 4}));
}
TEST(type_prop_layers, roi_align_num_rois_from_batch_indices)
{
const auto data = make_shared<op::Parameter>(element::f32, PartialShape{10, 3, 5, 5});
const auto rois =
make_shared<op::Parameter>(element::f32, PartialShape{Dimension{}, Dimension{}});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{9});
const auto op = make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 3, 4, 1, 1.0f, "avg");
ASSERT_EQ(op->get_shape(), (Shape{9, 3, 3, 4}));
}
TEST(type_prop_layers, roi_align_incompatible_num_rois)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 3, 5, 5});
const auto rois = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension{}});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{2});
// the first dimension of rois and batch_indices should be equal
ASSERT_THROW(make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 3, 4, 1, 1.0f, "avg"),
ngraph::NodeValidationFailure);
}
TEST(type_prop_layers, roi_align_incompatible_input_rank)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{1, 10, 3, 5, 5});
const auto rois = make_shared<op::Parameter>(element::f32, Shape{1, 4});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{1});
// data rank needs to be 4
ASSERT_THROW(make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 3, 4, 1, 1.0f, "avg"),
ngraph::NodeValidationFailure);
}
TEST(type_prop_layers, roi_align_incompatible_rois_second_dim)
{
const auto data = make_shared<op::Parameter>(element::f32, Shape{10, 3, 5, 5});
const auto rois = make_shared<op::Parameter>(element::f32, Shape{1, 5});
const auto batch_indices = make_shared<op::Parameter>(element::i32, Shape{1});
// the second dim of rois needs to be 4
ASSERT_THROW(make_shared<op::v3::ROIAlign>(data, rois, batch_indices, 3, 4, 1, 1.0f, "avg"),
ngraph::NodeValidationFailure);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment