Commit e5bc0854 authored by Adam Rogowiec's avatar Adam Rogowiec Committed by Scott Cyphers

[SPEC] Anchors attribute for RegionYolo operator. (#3854)

* Add anchors attribute to RegionYolo operator.

* Make op doc more consistent.

* Make anchors optional attribute.
parent aa9a2b91
...@@ -28,13 +28,15 @@ op::RegionYolo::RegionYolo(const Output<Node>& input, ...@@ -28,13 +28,15 @@ op::RegionYolo::RegionYolo(const Output<Node>& input,
const bool do_softmax, const bool do_softmax,
const vector<int64_t>& mask, const vector<int64_t>& mask,
const int axis, const int axis,
const int end_axis) const int end_axis,
const vector<float>& anchors)
: Op({input}) : Op({input})
, m_num_coords(num_coords) , m_num_coords(num_coords)
, m_num_classes(num_classes) , m_num_classes(num_classes)
, m_num_regions(num_regions) , m_num_regions(num_regions)
, m_do_softmax(do_softmax) , m_do_softmax(do_softmax)
, m_mask(mask) , m_mask(mask)
, m_anchors(anchors)
, m_axis(axis) , m_axis(axis)
, m_end_axis(end_axis) , m_end_axis(end_axis)
{ {
...@@ -96,5 +98,6 @@ shared_ptr<Node> op::RegionYolo::copy_with_new_args(const NodeVector& new_args) ...@@ -96,5 +98,6 @@ shared_ptr<Node> op::RegionYolo::copy_with_new_args(const NodeVector& new_args)
m_do_softmax, m_do_softmax,
m_mask, m_mask,
m_axis, m_axis,
m_end_axis); m_end_axis,
m_anchors);
} }
...@@ -28,16 +28,20 @@ namespace ngraph ...@@ -28,16 +28,20 @@ namespace ngraph
NGRAPH_API NGRAPH_API
static constexpr NodeTypeInfo type_info{"RegionYolo", 0}; static constexpr NodeTypeInfo type_info{"RegionYolo", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; } const NodeTypeInfo& get_type_info() const override { return type_info; }
///
/// \brief Constructs a RegionYolo operation /// \brief Constructs a RegionYolo operation
/// ///
/// \param input Input /// \param[in] input Input
/// \param num_coords Number of coordinates for each region /// \param[in] num_coords Number of coordinates for each region
/// \param num_classes Number of classes for each region /// \param[in] num_classes Number of classes for each region
/// \param num_regions Number of regions /// \param[in] num_regions Number of regions
/// \param do_softmax Compute softmax /// \param[in] do_softmax Compute softmax
/// \param mask Mask /// \param[in] mask Mask
/// \param axis Axis to begin softmax on /// \param[in] axis Axis to begin softmax on
/// \param end_axis Axis to end softmax on /// \param[in] end_axis Axis to end softmax on
/// \param[in] anchors A flattened list of pairs `[width, height]` that describes
/// prior box sizes.
///
RegionYolo(const Output<Node>& input, RegionYolo(const Output<Node>& input,
const size_t num_coords, const size_t num_coords,
const size_t num_classes, const size_t num_classes,
...@@ -45,7 +49,8 @@ namespace ngraph ...@@ -45,7 +49,8 @@ namespace ngraph
const bool do_softmax, const bool do_softmax,
const std::vector<int64_t>& mask, const std::vector<int64_t>& mask,
const int axis, const int axis,
const int end_axis); const int end_axis,
const std::vector<float>& anchors = std::vector<float>{});
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -57,6 +62,7 @@ namespace ngraph ...@@ -57,6 +62,7 @@ namespace ngraph
size_t get_num_regions() const { return m_num_regions; } size_t get_num_regions() const { return m_num_regions; }
bool get_do_softmax() const { return m_do_softmax; } bool get_do_softmax() const { return m_do_softmax; }
const std::vector<int64_t>& get_mask() const { return m_mask; } const std::vector<int64_t>& get_mask() const { return m_mask; }
const std::vector<float>& get_anchors() const { return m_anchors; }
int get_axis() const { return m_axis; } int get_axis() const { return m_axis; }
int get_end_axis() const { return m_end_axis; } int get_end_axis() const { return m_end_axis; }
private: private:
...@@ -65,6 +71,7 @@ namespace ngraph ...@@ -65,6 +71,7 @@ namespace ngraph
size_t m_num_regions; size_t m_num_regions;
bool m_do_softmax; bool m_do_softmax;
std::vector<int64_t> m_mask; std::vector<int64_t> m_mask;
std::vector<float> m_anchors{};
int m_axis; int m_axis;
int m_end_axis; int m_end_axis;
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment