Commit 6c8b5650 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Added mode attributes to DynPad and TopK (#3093)

* Added mode attributes to DynPad and TopK

* Rename sort_type to sort

* Throw error in pad reference implementation for symmetric mode
parent ba226a14
...@@ -22,7 +22,8 @@ using namespace ngraph; ...@@ -22,7 +22,8 @@ using namespace ngraph;
op::DynPad::DynPad(const std::shared_ptr<Node>& arg, op::DynPad::DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below, const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above, const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value) const std::shared_ptr<Node>& padding_value,
op::PadMode pad_mode)
: Op("DynPad", check_single_output_args({arg, padding_below, padding_above, padding_value})) : Op("DynPad", check_single_output_args({arg, padding_below, padding_above, padding_value}))
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
...@@ -102,7 +103,8 @@ void op::DynPad::validate_and_infer_types() ...@@ -102,7 +103,8 @@ void op::DynPad::validate_and_infer_types()
// Creates a copy of this DynPad node wired to new input nodes.
// Inputs are, in order: arg, padding_below, padding_above, padding_value.
// The pad mode attribute (m_pad_mode) is carried over so the clone is
// attribute-equivalent to the original node.
shared_ptr<Node> op::DynPad::copy_with_new_args(const NodeVector& new_args) const
{
    check_new_args_count(this, new_args);
    return make_shared<DynPad>(
        new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_pad_mode);
}
// TODO: This function is not implemented! // TODO: This function is not implemented!
......
...@@ -33,11 +33,14 @@ namespace ngraph ...@@ -33,11 +33,14 @@ namespace ngraph
/// \param padding_below The node producing the padding-below widths. /// \param padding_below The node producing the padding-below widths.
/// \param padding_above The node producing the padding-above widths. /// \param padding_above The node producing the padding-above widths.
/// \param padding_value The value to be used for padding. Must be scalar. /// \param padding_value The value to be used for padding. Must be scalar.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE or REFLECT.
DynPad(const std::shared_ptr<Node>& arg, DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below, const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above, const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value); const std::shared_ptr<Node>& padding_value,
PadMode pad_mode = PadMode::CONSTANT);
PadMode get_pad_mode() const { return m_pad_mode; }
void validate_and_infer_types() override; void validate_and_infer_types() override;
virtual std::shared_ptr<Node> virtual std::shared_ptr<Node>
...@@ -46,6 +49,9 @@ namespace ngraph ...@@ -46,6 +49,9 @@ namespace ngraph
protected: protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
private:
PadMode m_pad_mode;
}; };
} }
} }
...@@ -34,11 +34,13 @@ op::TopK::TopK(const Output<Node>& arg, ...@@ -34,11 +34,13 @@ op::TopK::TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k, size_t k,
bool compute_max) bool compute_max,
SortType sort)
: Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)}) : Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
, m_top_k_axis(top_k_axis) , m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type) , m_index_element_type(index_element_type)
, m_compute_max(compute_max) , m_compute_max(compute_max)
, m_sort(sort)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -47,11 +49,13 @@ op::TopK::TopK(const Output<Node>& arg, ...@@ -47,11 +49,13 @@ op::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k, const Output<Node>& k,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
bool compute_max) bool compute_max,
SortType sort)
: Op({arg, k}) : Op({arg, k})
, m_top_k_axis(top_k_axis) , m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type) , m_index_element_type(index_element_type)
, m_compute_max(compute_max) , m_compute_max(compute_max)
, m_sort(sort)
{ {
constructor_validate_and_infer_types(); constructor_validate_and_infer_types();
} }
...@@ -130,7 +134,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const ...@@ -130,7 +134,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
{ {
check_new_args_count(this, new_args); check_new_args_count(this, new_args);
return make_shared<TopK>( return make_shared<TopK>(
new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max); new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max, m_sort);
} }
void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
...@@ -30,6 +30,16 @@ namespace ngraph ...@@ -30,6 +30,16 @@ namespace ngraph
class TopK : public Op class TopK : public Op
{ {
public: public:
/// \brief Ordering applied to the k results produced by TopK.
enum class SortType
{
    // Returned values are not sorted
    NONE,
    // Sort result based on element indices
    SORT_INDICES,
    // Sort result based on element values
    SORT_VALUES,
};
NGRAPH_API NGRAPH_API
static const std::string type_name; static const std::string type_name;
const std::string& description() const override { return type_name; } const std::string& description() const override { return type_name; }
...@@ -42,11 +52,13 @@ namespace ngraph ...@@ -42,11 +52,13 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported /// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param k Number of top indices to compute. Compute all indices if k = 0 /// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
size_t k = 0, size_t k = 0,
bool compute_max = true); bool compute_max = true,
SortType sort = SortType::NONE);
/// \brief Constructs a TopK operation. /// \brief Constructs a TopK operation.
/// ///
/// \param arg The input tensor /// \param arg The input tensor
...@@ -54,11 +66,13 @@ namespace ngraph ...@@ -54,11 +66,13 @@ namespace ngraph
/// \param top_k_axis The axis along which to compute top k indices /// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported /// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param compute_max Compute top k max or top k min? /// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
TopK(const Output<Node>& arg, TopK(const Output<Node>& arg,
const Output<Node>& k, const Output<Node>& k,
size_t top_k_axis, size_t top_k_axis,
const element::Type& index_element_type, const element::Type& index_element_type,
bool compute_max = true); bool compute_max = true,
SortType sort = SortType::NONE);
void validate_and_infer_types() override; void validate_and_infer_types() override;
...@@ -71,10 +85,12 @@ namespace ngraph ...@@ -71,10 +85,12 @@ namespace ngraph
size_t get_top_k_axis() const { return m_top_k_axis; } size_t get_top_k_axis() const { return m_top_k_axis; }
element::Type get_index_element_type() const { return m_index_element_type; } element::Type get_index_element_type() const { return m_index_element_type; }
bool get_compute_max() const { return m_compute_max; } bool get_compute_max() const { return m_compute_max; }
SortType get_sort() const { return m_sort; }
protected: protected:
size_t m_top_k_axis{0}; size_t m_top_k_axis{0};
element::Type m_index_element_type; element::Type m_index_element_type;
bool m_compute_max{false}; bool m_compute_max{false};
SortType m_sort;
virtual void generate_adjoints(autodiff::Adjoints& adjoints, virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override; const NodeVector& deltas) override;
}; };
......
...@@ -27,7 +27,8 @@ namespace ngraph ...@@ -27,7 +27,8 @@ namespace ngraph
{ {
CONSTANT = 0, CONSTANT = 0,
EDGE, EDGE,
REFLECT REFLECT,
SYMMETRIC
}; };
/// \brief Padding Type used for `Convolution` and `Pooling` /// \brief Padding Type used for `Convolution` and `Pooling`
......
...@@ -162,6 +162,11 @@ namespace ngraph ...@@ -162,6 +162,11 @@ namespace ngraph
v = arg0[input_transform.index(c)]; v = arg0[input_transform.index(c)];
break; break;
} }
case op::PadMode::SYMMETRIC:
{
// TODO: Add support for Symmetric mode
throw ngraph_error("Symmetric mode padding not supported");
}
} }
out[output_transform.index(out_coord)] = v; out[output_transform.index(out_coord)] = v;
......
...@@ -10028,6 +10028,7 @@ TEST(type_prop, topk_rank_dynamic_ok) ...@@ -10028,6 +10028,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32); ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic()); ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE);
} }
TEST(type_prop, topk_rank_dynamic_result_et_dynamic) TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment