Commit f12a5c92 authored by Tomasz Dołbniak's avatar Tomasz Dołbniak Committed by Scott Cyphers

[SPEC] NonMaxSuppression (#3968)

* NonMaxSuppression op skeleton

* Validation of the NonMaxSuppression op

* Correct last 'boxes' dimension check

* onnx_importer support for NonMaxSuppression

* Code formatting

* Type and shape inference for NonMaxSuppression

* Different initialization of NMS inputs in onnx_importer

* Code formatting

* Basic type_prop tests for NonMaxSuppression

* More type_prop validation for NMS
parent edc5d6ba
......@@ -259,6 +259,8 @@ set (SRC
op/multiply.hpp
op/negative.cpp
op/negative.hpp
op/non_max_suppression.cpp
op/non_max_suppression.hpp
op/not.cpp
op/not.hpp
op/not_equal.cpp
......
......@@ -140,6 +140,8 @@ add_library(onnx_import STATIC
op/mul.hpp
op/neg.hpp
op/not.hpp
op/non_max_suppression.cpp
op/non_max_suppression.hpp
op/onehot.cpp
op/onehot.hpp
op/or.hpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include "ngraph/frontend/onnx_import/exceptions.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "non_max_suppression.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector non_max_suppression(const Node& node)
{
// TODO: this op will not be tested until at least
// a reference implementation is added
const auto ng_inputs = node.get_ng_inputs();
const std::shared_ptr<ngraph::Node> boxes = ng_inputs.at(0);
const std::shared_ptr<ngraph::Node> scores = ng_inputs.at(1);
std::shared_ptr<ngraph::Node> max_output_boxes_per_class;
if (ng_inputs.size() > 2)
{
max_output_boxes_per_class = ng_inputs.at(2);
}
else
{
max_output_boxes_per_class =
ngraph::op::Constant::create(element::i64, Shape{}, {0});
}
std::shared_ptr<ngraph::Node> iou_threshold;
if (ng_inputs.size() > 3)
{
iou_threshold = ng_inputs.at(3);
}
else
{
iou_threshold = ngraph::op::Constant::create(element::f32, Shape{}, {.0f});
}
std::shared_ptr<ngraph::Node> score_threshold;
if (ng_inputs.size() > 4)
{
score_threshold = ng_inputs.at(4);
}
else
{
score_threshold =
ngraph::op::Constant::create(element::f32, Shape{}, {.0f});
}
const auto center_point_box =
node.get_attribute_value<std::int64_t>("center_point_box", 0);
ASSERT_IS_SUPPORTED(node, center_point_box == 0 || center_point_box == 1)
<< "Allowed values of the 'center_point_box' attribute are 0 and 1.";
const auto box_encoding =
center_point_box == 0
? ngraph::op::v1::NonMaxSuppression::BoxEncodingType::CORNER
: ngraph::op::v1::NonMaxSuppression::BoxEncodingType::CENTER;
return {std::make_shared<ngraph::op::v1::NonMaxSuppression>(
boxes,
scores,
max_output_boxes_per_class,
iou_threshold,
score_threshold,
box_encoding,
false)};
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
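For illustration only (this snippet is not part of the commit's diff): when an ONNX node carries just the two required inputs, the fallbacks above mean the importer emits a graph equivalent to the direct construction below. The parameter shapes are hypothetical.

#include <memory>
#include "ngraph/ngraph.hpp"

std::shared_ptr<ngraph::Node> default_nms_sketch()
{
    using namespace ngraph;
    // boxes: [num_batches, num_boxes, 4]; scores: [num_batches, num_classes, num_boxes]
    const auto boxes = std::make_shared<op::Parameter>(element::f32, Shape{1, 6, 4});
    const auto scores = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 6});
    // Scalar defaults matching the importer's fallbacks for absent optional inputs.
    const auto max_out = op::Constant::create(element::i64, Shape{}, {0});
    const auto iou_threshold = op::Constant::create(element::f32, Shape{}, {.0f});
    const auto score_threshold = op::Constant::create(element::f32, Shape{}, {.0f});
    return std::make_shared<op::v1::NonMaxSuppression>(
        boxes,
        scores,
        max_out,
        iou_threshold,
        score_threshold,
        op::v1::NonMaxSuppression::BoxEncodingType::CORNER,
        false); // the importer always passes sort_result_descending = false
}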
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/and.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector non_max_suppression(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
......@@ -85,6 +85,7 @@
#include "op/mod.hpp"
#include "op/mul.hpp"
#include "op/neg.hpp"
#include "op/non_max_suppression.hpp"
#include "op/not.hpp"
#include "op/onehot.hpp"
#include "op/or.hpp"
......@@ -306,6 +307,7 @@ namespace ngraph
REGISTER_OPERATOR("Mul", 1, mul);
REGISTER_OPERATOR("Mul", 7, mul);
REGISTER_OPERATOR("Neg", 1, neg);
REGISTER_OPERATOR("NonMaxSuppression", 1, non_max_suppression);
REGISTER_OPERATOR("Not", 1, logical_not);
REGISTER_OPERATOR("Or", 1, logical_or);
REGISTER_OPERATOR("OneHot", 1, onehot);
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/op/constant.hpp"
using namespace std;
using namespace ngraph;
constexpr NodeTypeInfo op::v1::NonMaxSuppression::type_info;
op::v1::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const op::v1::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending)
: Op({boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
{
constructor_validate_and_infer_types();
}
op::v1::NonMaxSuppression::NonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
const op::v1::NonMaxSuppression::BoxEncodingType box_encoding,
const bool sort_result_descending)
: Op({boxes,
scores,
op::Constant::create(element::i64, Shape{}, {0}),
op::Constant::create(element::f32, Shape{}, {.0f}),
op::Constant::create(element::f32, Shape{}, {.0f})})
, m_box_encoding{box_encoding}
, m_sort_result_descending{sort_result_descending}
{
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::v1::NonMaxSuppression::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<op::v1::NonMaxSuppression>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
m_box_encoding,
m_sort_result_descending);
}
void op::v1::NonMaxSuppression::validate_and_infer_types()
{
const auto boxes_ps = get_input_partial_shape(0);
const auto scores_ps = get_input_partial_shape(1);
NODE_VALIDATION_CHECK(this,
boxes_ps.rank().is_static() && static_cast<size_t>(boxes_ps.rank()) == 3,
"Expected a 3D tensor for the 'boxes' input. Got: ",
boxes_ps);
NODE_VALIDATION_CHECK(this,
scores_ps.rank().is_static() &&
static_cast<size_t>(scores_ps.rank()) == 3,
"Expected a 3D tensor for the 'scores' input. Got: ",
scores_ps);
const auto max_boxes_ps = get_input_partial_shape(2);
NODE_VALIDATION_CHECK(this,
max_boxes_ps.is_dynamic() || is_scalar(max_boxes_ps.to_shape()),
"Expected a scalar for the 'max_output_boxes_per_class' input. Got: ",
max_boxes_ps);
const auto iou_threshold_ps = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
iou_threshold_ps.is_dynamic() || is_scalar(iou_threshold_ps.to_shape()),
"Expected a scalar for the 'iou_threshold' input. Got: ",
iou_threshold_ps);
const auto score_threshold_ps = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
score_threshold_ps.is_dynamic() ||
is_scalar(score_threshold_ps.to_shape()),
"Expected a scalar for the 'score_threshold' input. Got: ",
score_threshold_ps);
const auto num_batches_boxes = boxes_ps[0];
const auto num_batches_scores = scores_ps[0];
NODE_VALIDATION_CHECK(this,
num_batches_boxes.same_scheme(num_batches_scores),
"The first dimension of both 'boxes' and 'scores' must match. Boxes: ",
num_batches_boxes,
"; Scores: ",
num_batches_scores);
const auto num_boxes_boxes = boxes_ps[1];
const auto num_boxes_scores = scores_ps[2];
NODE_VALIDATION_CHECK(this,
num_boxes_boxes.same_scheme(num_boxes_scores),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively. Boxes: ",
num_boxes_boxes,
"; Scores: ",
num_boxes_scores);
NODE_VALIDATION_CHECK(this,
boxes_ps[2].is_static() && static_cast<size_t>(boxes_ps[2]) == 4u,
"The last dimension of the 'boxes' input must be equal to 4. Got:",
boxes_ps[2]);
// NonMaxSuppression produces triplets
// that have the following format: [batch_index, class_index, box_index]
// The number of returned triplets depends entirely on the computation, thus one dynamic dim
const PartialShape out_shape = {Dimension::dynamic(), 3};
set_output_size(1);
set_output_type(0, element::i64, out_shape);
}
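A quick illustration of the inference above (a sketch using only constructs from this commit): even with fully static input shapes, the op reports a single i64 output of shape [dynamic, 3], since the number of selected triplets is data-dependent.

#include "ngraph/ngraph.hpp"

void nms_output_shape_sketch()
{
    using namespace ngraph;
    const auto boxes = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
    const auto scores = std::make_shared<op::Parameter>(element::f32, Shape{1, 2, 2});
    const auto nms = std::make_shared<op::v1::NonMaxSuppression>(boxes, scores);
    // nms->output(0).get_element_type() == element::i64
    // nms->output(0).get_partial_shape() is {Dimension::dynamic(), 3}:
    // an unknown number of [batch_index, class_index, box_index] triplets.
}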
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v1
{
/// \brief Performs non-maximum suppression of the input boxes based on their scores.
///
class NGRAPH_API NonMaxSuppression : public Op
{
public:
enum class BoxEncodingType
{
CORNER,
CENTER
};
static constexpr NodeTypeInfo type_info{"NonMaxSuppression", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NonMaxSuppression() = default;
/// \brief Constructs a NonMaxSuppression operation.
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param max_output_boxes_per_class Node producing the maximum number of boxes to be
/// selected per class
/// \param iou_threshold Node producing the intersection-over-union threshold
/// \param score_threshold Node producing the minimum score threshold
/// \param box_encoding Specifies the format of the boxes data encoding
/// \param sort_result_descending Specifies whether the selected boxes should be returned
/// sorted by score in descending order
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& max_output_boxes_per_class,
const Output<Node>& iou_threshold,
const Output<Node>& score_threshold,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true);
/// \brief Constructs a NonMaxSuppression operation with default values for the last
/// 3 inputs
///
/// \param boxes Node producing the box coordinates
/// \param scores Node producing the box scores
/// \param box_encoding Specifies the format of the boxes data encoding
/// \param sort_result_descending Specifies whether the selected boxes should be returned
/// sorted by score in descending order
NonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const BoxEncodingType box_encoding = BoxEncodingType::CORNER,
const bool sort_result_descending = true);
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
BoxEncodingType get_box_encoding() const { return m_box_encoding; }
void set_box_encoding(const BoxEncodingType box_encoding)
{
m_box_encoding = box_encoding;
}
bool get_sort_result_descending() const { return m_sort_result_descending; }
void set_sort_result_descending(const bool sort_result_descending)
{
m_sort_result_descending = sort_result_descending;
}
protected:
BoxEncodingType m_box_encoding = BoxEncodingType::CORNER;
bool m_sort_result_descending = true;
};
}
}
}
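A usage sketch of the convenience two-input constructor declared above (shapes are made up; it relies only on the public API from this header).

#include <memory>
#include "ngraph/ngraph.hpp"

void nms_center_encoding_sketch()
{
    using namespace ngraph;
    const auto boxes = std::make_shared<op::Parameter>(element::f32, Shape{2, 8, 4});
    const auto scores = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 8});
    // Two-input form: the three threshold inputs fall back to scalar constants.
    auto nms = std::make_shared<op::v1::NonMaxSuppression>(
        boxes, scores, op::v1::NonMaxSuppression::BoxEncodingType::CENTER, false);
    // Accessors mirror the stored attributes:
    // nms->get_box_encoding() == BoxEncodingType::CENTER
    // nms->get_sort_result_descending() == false
    nms->set_sort_result_descending(true); // attributes can also be changed afterwards
}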
......@@ -153,6 +153,7 @@ NGRAPH_OP(Mod, ngraph::op::v1, 1)
NGRAPH_OP(Multiply, ngraph::op::v0, 0)
NGRAPH_OP(Multiply, ngraph::op::v1, 1)
NGRAPH_OP(Negative, ngraph::op, 0)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v1, 1)
NGRAPH_OP(NormalizeL2, ngraph::op::v0, 0)
NGRAPH_OP(Not, ngraph::op::v0, 0)
NGRAPH_OP(NotEqual, ngraph::op::v0, 0)
......
......@@ -135,6 +135,7 @@
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/one_hot.hpp"
......
......@@ -108,6 +108,7 @@ NGRAPH_OP(Minimum, ngraph::op::v1)
NGRAPH_OP(Mod, ngraph::op::v1)
NGRAPH_OP(Multiply, ngraph::op::v1)
NGRAPH_OP(Negative, ngraph::op::v0)
NGRAPH_OP(NonMaxSuppression, ngraph::op::v1)
NGRAPH_OP(NormalizeL2, ngraph::op::v0)
NGRAPH_OP(NotEqual, ngraph::op::v1)
NGRAPH_OP(OneHot, ngraph::op::v1)
......
......@@ -2210,6 +2210,17 @@ shared_ptr<Node> JSONDeserializer::deserialize_node(json node_js)
node = make_shared<op::Negative>(args[0]);
break;
}
case OP_TYPEID::NonMaxSuppression_v1:
{
const auto box_encoding =
node_js.at("box_encoding").get<op::v1::NonMaxSuppression::BoxEncodingType>();
const auto sort_result_descending = node_js.at("sort_result_descending").get<bool>();
node = make_shared<op::v1::NonMaxSuppression>(
args[0], args[1], args[2], args[3], args[4], box_encoding, sort_result_descending);
break;
}
case OP_TYPEID::NormalizeL2:
{
float eps = node_js.at("eps").get<float>();
......@@ -4042,6 +4053,13 @@ json JSONSerializer::serialize_node(const Node& n)
}
case OP_TYPEID::Negative: { break;
}
case OP_TYPEID::NonMaxSuppression_v1:
{
const auto tmp = static_cast<const op::v1::NonMaxSuppression*>(&n);
node["box_encoding"] = tmp->get_box_encoding();
node["sort_result_descending"] = tmp->get_sort_result_descending();
break;
}
case OP_TYPEID::NormalizeL2:
{
auto tmp = static_cast<const op::NormalizeL2*>(&n);
......
......@@ -152,6 +152,7 @@ set(SRC
type_prop/matmul.cpp
type_prop/max_pool.cpp
type_prop/mvn.cpp
type_prop/non_max_suppression.cpp
type_prop/normalize.cpp
type_prop/one_hot.cpp
type_prop/pad.cpp
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, nms_incorrect_boxes_rank)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
make_shared<op::v1::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'boxes' input");
}
}
TEST(type_prop, nms_incorrect_scores_rank)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2});
make_shared<op::v1::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'scores' input");
}
}
TEST(type_prop, nms_incorrect_scheme_num_batches)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{2, 2, 3});
make_shared<op::v1::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"The first dimension of both 'boxes' and 'scores' must match");
}
}
TEST(type_prop, nms_incorrect_scheme_num_boxes)
{
try
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
make_shared<op::v1::NonMaxSuppression>(boxes, scores);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"'boxes' and 'scores' input shapes must match at the second and third "
"dimension respectively");
}
}
TEST(type_prop, nms_scalar_inputs_check)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 2});
const auto scalar = make_shared<op::Parameter>(element::f32, Shape{});
const auto non_scalar = make_shared<op::Parameter>(element::f32, Shape{1});
try
{
make_shared<op::v1::NonMaxSuppression>(boxes, scores, non_scalar, scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
"Expected a scalar for the 'max_output_boxes_per_class' input");
}
try
{
make_shared<op::v1::NonMaxSuppression>(boxes, scores, scalar, non_scalar, scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'iou_threshold' input");
}
try
{
make_shared<op::v1::NonMaxSuppression>(boxes, scores, scalar, scalar, non_scalar);
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(), "Expected a scalar for the 'score_threshold' input");
}
}
TEST(type_prop, nms_out_shape)
{
const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});
const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 2});
const auto nms = make_shared<op::v1::NonMaxSuppression>(boxes, scores);
const auto nms_out_ps = nms->output(0).get_partial_shape();
EXPECT_TRUE(nms_out_ps.rank().is_static());
EXPECT_EQ(static_cast<size_t>(nms_out_ps.rank()), 2);
EXPECT_EQ(static_cast<size_t>(nms_out_ps[1]), 3);
}
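One possible hardening of the negative tests above (a sketch, not part of the commit): as written, each test also passes when no exception is thrown at all; a gtest FAIL() after the construction call makes the expectation explicit.

TEST(type_prop, nms_incorrect_boxes_rank_strict)
{
    try
    {
        const auto boxes = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
        const auto scores = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3});
        make_shared<op::v1::NonMaxSuppression>(boxes, scores);
        // Reaching this point means validation did not reject the 4D 'boxes' input.
        FAIL() << "NodeValidationFailure expected for a non-3D 'boxes' input";
    }
    catch (const NodeValidationFailure& error)
    {
        EXPECT_HAS_SUBSTRING(error.what(), "Expected a 3D tensor for the 'boxes' input");
    }
}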