Commit 160b91bf authored by Michał Karzyński's avatar Michał Karzyński Committed by Scott Cyphers

[Spec] Add 3-input constructor to DetectionOutput (#3966)

* Add 3-input constructor to DetectionOutput

* Review comments
parent f749c9d0
...@@ -33,6 +33,16 @@ op::DetectionOutput::DetectionOutput(const Output<Node>& box_logits, ...@@ -33,6 +33,16 @@ op::DetectionOutput::DetectionOutput(const Output<Node>& box_logits,
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
op::DetectionOutput::DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const DetectionOutputAttrs& attrs)
: Op({box_logits, class_preds, proposals})
, m_attrs(attrs)
{
constructor_validate_and_infer_types();
}
void op::DetectionOutput::validate_and_infer_types() void op::DetectionOutput::validate_and_infer_types()
{ {
if (get_input_partial_shape(0).is_static()) if (get_input_partial_shape(0).is_static())
...@@ -50,6 +60,24 @@ void op::DetectionOutput::validate_and_infer_types() ...@@ -50,6 +60,24 @@ void op::DetectionOutput::validate_and_infer_types()
shared_ptr<Node> op::DetectionOutput::copy_with_new_args(const NodeVector& new_args) const shared_ptr<Node> op::DetectionOutput::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
auto num_args = new_args.size();
NODE_VALIDATION_CHECK(
this, num_args == 3 || num_args == 5, "DetectionOutput accepts 3 or 5 inputs.");
if (num_args == 3)
{
return make_shared<DetectionOutput>( return make_shared<DetectionOutput>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), m_attrs); new_args.at(0), new_args.at(1), new_args.at(2), m_attrs);
}
else
{
return make_shared<DetectionOutput>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
m_attrs);
}
} }
...@@ -67,6 +67,17 @@ namespace ngraph ...@@ -67,6 +67,17 @@ namespace ngraph
const Output<Node>& aux_box_preds, const Output<Node>& aux_box_preds,
const DetectionOutputAttrs& attrs); const DetectionOutputAttrs& attrs);
/// \brief Constructs a DetectionOutput operation
///
/// \param box_logits Box logits
/// \param class_preds Class predictions
/// \param proposals Proposals
/// \param attrs Detection Output attributes
DetectionOutput(const Output<Node>& box_logits,
const Output<Node>& class_preds,
const Output<Node>& proposals,
const DetectionOutputAttrs& attrs);
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment