Commit 6c8b5650 authored by Jayaram Bobba's avatar Jayaram Bobba Committed by Scott Cyphers

Added mode attributes to DynPad and TopK (#3093)

* Added mode attributes to DynPad and TopK

* Rename sort_type to sort

* Throw error in pad reference implementation for symmetric mode
parent ba226a14
......@@ -22,7 +22,8 @@ using namespace ngraph;
op::DynPad::DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value)
const std::shared_ptr<Node>& padding_value,
op::PadMode pad_mode)
: Op("DynPad", check_single_output_args({arg, padding_below, padding_above, padding_value}))
{
constructor_validate_and_infer_types();
......@@ -102,7 +103,8 @@ void op::DynPad::validate_and_infer_types()
shared_ptr<Node> op::DynPad::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<DynPad>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
return make_shared<DynPad>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), m_pad_mode);
}
// TODO: This function is not implemented!
......
......@@ -33,11 +33,14 @@ namespace ngraph
/// \param padding_below The node producing the padding-below widths.
/// \param padding_above The node producing the padding-above widths.
/// \param padding_value The value to be used for padding. Must be scalar.
/// \param pad_mode The padding mode: CONSTANT(default), EDGE or REFLECT.
DynPad(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& padding_below,
const std::shared_ptr<Node>& padding_above,
const std::shared_ptr<Node>& padding_value);
const std::shared_ptr<Node>& padding_value,
PadMode pad_mode = PadMode::CONSTANT);
PadMode get_pad_mode() const { return m_pad_mode; }
void validate_and_infer_types() override;
virtual std::shared_ptr<Node>
......@@ -46,6 +49,9 @@ namespace ngraph
protected:
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
private:
PadMode m_pad_mode;
};
}
}
......@@ -34,11 +34,13 @@ op::TopK::TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k,
bool compute_max)
bool compute_max,
SortType sort)
: Op({arg, op::Constant::create(element::i64, Shape{1}, {k})->output(0)})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
, m_sort(sort)
{
constructor_validate_and_infer_types();
}
......@@ -47,11 +49,13 @@ op::TopK::TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max)
bool compute_max,
SortType sort)
: Op({arg, k})
, m_top_k_axis(top_k_axis)
, m_index_element_type(index_element_type)
, m_compute_max(compute_max)
, m_sort(sort)
{
constructor_validate_and_infer_types();
}
......@@ -130,7 +134,7 @@ shared_ptr<Node> op::TopK::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<TopK>(
new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max);
new_args.at(0), new_args.at(1), m_top_k_axis, m_index_element_type, m_compute_max, m_sort);
}
void op::TopK::generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas)
......
......@@ -30,6 +30,16 @@ namespace ngraph
class TopK : public Op
{
public:
enum class SortType
{
// Returned values are not sorted
NONE,
// Sort result based on element indices
SORT_INDICES,
// Sort result based on element values
SORT_VALUES,
};
NGRAPH_API
static const std::string type_name;
const std::string& description() const override { return type_name; }
......@@ -42,11 +52,13 @@ namespace ngraph
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param k Number of top indices to compute. Compute all indices if k = 0
/// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
TopK(const Output<Node>& arg,
size_t top_k_axis,
const element::Type& index_element_type,
size_t k = 0,
bool compute_max = true);
bool compute_max = true,
SortType sort = SortType::NONE);
/// \brief Constructs a TopK operation.
///
/// \param arg The input tensor
......@@ -54,11 +66,13 @@ namespace ngraph
/// \param top_k_axis The axis along which to compute top k indices
/// \param index_element_type produce indices. Currently, only int64 or int32 are supported
/// \param compute_max Compute top k max or top k min?
/// \param sort SortType for sorting results, default - NONE
TopK(const Output<Node>& arg,
const Output<Node>& k,
size_t top_k_axis,
const element::Type& index_element_type,
bool compute_max = true);
bool compute_max = true,
SortType sort = SortType::NONE);
void validate_and_infer_types() override;
......@@ -71,10 +85,12 @@ namespace ngraph
size_t get_top_k_axis() const { return m_top_k_axis; }
element::Type get_index_element_type() const { return m_index_element_type; }
bool get_compute_max() const { return m_compute_max; }
SortType get_sort() const { return m_sort; }
protected:
size_t m_top_k_axis{0};
element::Type m_index_element_type;
bool m_compute_max{false};
SortType m_sort;
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
};
......
......@@ -27,7 +27,8 @@ namespace ngraph
{
CONSTANT = 0,
EDGE,
REFLECT
REFLECT,
SYMMETRIC
};
/// \brief Padding Type used for `Convolution` and `Pooling`
......
......@@ -162,6 +162,11 @@ namespace ngraph
v = arg0[input_transform.index(c)];
break;
}
case op::PadMode::SYMMETRIC:
{
// TODO: Add support for Symmetric mode
throw ngraph_error("Symmetric mode padding not supported");
}
}
out[output_transform.index(out_coord)] = v;
......
......@@ -10028,6 +10028,7 @@ TEST(type_prop, topk_rank_dynamic_ok)
ASSERT_TRUE(topk->get_output_element_type(1) == element::f32);
ASSERT_TRUE(topk->get_output_partial_shape(0).rank().is_dynamic());
ASSERT_TRUE(topk->get_output_partial_shape(1).rank().is_dynamic());
ASSERT_TRUE(topk->get_sort() == op::TopK::SortType::NONE);
}
TEST(type_prop, topk_rank_dynamic_result_et_dynamic)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment