Commit 206bc657 authored by Mateusz Bencer's avatar Mateusz Bencer Committed by Michał Karzyński

Calculate output shape for NonMaxSuppression (#4001)

parent 5dc3a1bb
......@@ -126,9 +126,60 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
// NonMaxSuppression produces triplets
// that have the following format: [batch_index, class_index, box_index]
// The number of returned triplets depends entirely on the computation, thus one dynamic dim
const PartialShape out_shape = {Dimension::dynamic(), 3};
PartialShape out_shape = {Dimension::dynamic(), 3};
const auto max_output_boxes_per_class = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
max_output_boxes_per_class->is_constant())
{
const auto num_boxes = static_cast<int64_t>(num_boxes_boxes);
const auto max_output_boxes_per_class = max_boxes_output_from_input();
const auto num_classes = static_cast<int64_t>(scores_ps[1]);
out_shape[0] = std::min(num_boxes, max_output_boxes_per_class * num_classes);
}
set_output_size(1);
set_output_type(0, element::i64, out_shape);
}
int64_t op::v1::NonMaxSuppression::max_boxes_output_from_input() const
{
int64_t max_output_boxes{0};
const auto max_output_boxes_input =
as_type_ptr<op::Constant>(input_value(2).get_node_shared_ptr());
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wswitch-enum"
#endif
switch (static_cast<element::Type_t>(max_output_boxes_input->get_element_type()))
{
case element::Type_t::i8:
{
max_output_boxes = max_output_boxes_input->get_vector<int8_t>().at(0);
break;
}
case element::Type_t::i16:
{
max_output_boxes = max_output_boxes_input->get_vector<int16_t>().at(0);
break;
}
case element::Type_t::i32:
{
max_output_boxes = max_output_boxes_input->get_vector<int32_t>().at(0);
break;
}
case element::Type_t::i64:
{
max_output_boxes = max_output_boxes_input->get_vector<int64_t>().at(0);
break;
}
default: break;
}
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
return max_output_boxes;
}
......@@ -87,6 +87,9 @@ namespace ngraph
protected:
BoxEncodingType m_box_encoding = BoxEncodingType::CORNER;
bool m_sort_result_descending = true;
private:
int64_t max_boxes_output_from_input() const;
};
}
}
......
......@@ -121,7 +121,7 @@ TEST(type_prop, nms_scalar_inputs_check)
}
}
TEST(type_prop, nms_out_shape)
TEST(type_prop, nms_output_shape)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 2});
......@@ -133,3 +133,33 @@ TEST(type_prop, nms_out_shape)
EXPECT_EQ(static_cast<size_t>(nms_out_ps.rank()), 2);
EXPECT_EQ(static_cast<size_t>(nms_out_ps[1]), 3);
}
TEST(type_prop, nms_output_shape_2)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 6, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 6});
const auto max_output_boxes_per_class = op::Constant::create(element::i32, Shape{}, {3});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms = make_shared<op::v1::NonMaxSuppression>(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold);
ASSERT_EQ(nms->get_element_type(), element::i64);
ASSERT_EQ(nms->get_shape(), (Shape{3, 3}));
}
TEST(type_prop, nms_output_shape_3)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 1, 1});
const auto max_output_boxes_per_class = op::Constant::create(element::i16, Shape{}, {3});
const auto iou_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto score_threshold = make_shared<op::Parameter>(element::f32, Shape{});
const auto nms = make_shared<op::v1::NonMaxSuppression>(
boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold);
ASSERT_EQ(nms->get_element_type(), element::i64);
ASSERT_EQ(nms->get_shape(), (Shape{1, 3}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment