Commit 8b091114 authored by Jayaram Bobba and committed by Scott Cyphers

Added extra attributes to DynSlice (#2862)

* Add more slicing attributes to dynslice

* Added output shape computation for dyn slice

* Bug fixes and added unit tests

* Style fix

* Addressed PR feedback
parent dddcd4a8
......@@ -144,6 +144,29 @@ vector<string> op::Constant::get_value_strings() const
return rc;
}
Shape op::Constant::get_shape_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_shape = get_vector<int64_t>();
Shape output_shape(shape_size(m_shape));
std::transform(out_shape.begin(), out_shape.end(), output_shape.begin(), [&](const int64_t& v) {
return (v > 0) ? v : 0;
});
return output_shape;
}
Strides op::Constant::get_strides_val() const
{
NGRAPH_CHECK(m_element_type == element::i64);
std::vector<int64_t> out_strides = get_vector<int64_t>();
Strides output_strides(shape_size(m_shape));
std::transform(out_strides.begin(),
out_strides.end(),
output_strides.begin(),
[&](const int64_t& v) { return (v > 0) ? v : 0; });
return output_strides;
}
shared_ptr<Node> op::Constant::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
......
......@@ -152,6 +152,15 @@ namespace ngraph
set_output_type(0, m_element_type, m_shape);
}
/// \brief Returns the value of the constant node as a Shape object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Shape get_shape_val() const;
/// \brief Returns the value of the constant node as a Strides object
/// Can only be used on element::i64 nodes and interprets
/// negative values as zeros.
Strides get_strides_val() const;
/// \brief Wrapper around constructing a shared_ptr of a Constant
///
/// \param type The element type of the tensor constant.
......
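The new accessors convert an element::i64 Constant into a Shape or Strides, clamping negative entries to zero. A minimal usage sketch, assuming only the op::Constant API shown above (the free function name is illustrative, not part of this commit):

#include <cstdint>
#include <vector>
#include "ngraph/op/constant.hpp"
using namespace ngraph;

// Sketch: the negative entry is clamped to 0 by get_shape_val().
Shape shape_from_constant()
{
    auto c = op::Constant::create(element::i64, Shape{3}, std::vector<int64_t>{2, -1, 4});
    return c->get_shape_val(); // yields Shape{2, 0, 4}
}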
......@@ -15,6 +15,9 @@
//*****************************************************************************
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/constant.hpp"
#include <memory>
using namespace std;
......@@ -23,12 +26,158 @@ using namespace ngraph;
op::DynSlice::DynSlice(const shared_ptr<Node>& arg,
const shared_ptr<Node>& lower_bounds,
const shared_ptr<Node>& upper_bounds,
const shared_ptr<Node>& strides)
const shared_ptr<Node>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis,
const AxisSet& shrink_axis,
const AxisSet& ellipsis_mask)
: Op("DynSlice", check_single_output_args({arg, lower_bounds, upper_bounds, strides}))
, m_lower_bounds_mask(lower_bounds_mask)
, m_upper_bounds_mask(upper_bounds_mask)
, m_new_axis(new_axis)
, m_shrink_axis(shrink_axis)
, m_ellipsis_mask(ellipsis_mask)
{
constructor_validate_and_infer_types();
}
Shape op::DynSlice::compute_output_shape() const
{
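// Static shape inference is only attempted when lower_bounds, upper_bounds and
// strides are all Constants. The loop below walks every output axis: the
// (at most one) ellipsis_mask position expands to cover the input axes not
// addressed by the bounds, bounds are defaulted when their mask bit is set,
// negative bounds are wrapped and then clipped to the input extent, new_axis
// positions do not consume an input axis and are emitted as size-1 dimensions,
// and shrink_axis positions are dropped from the output. If any of the three
// index inputs is not a Constant, Shape{} is returned and the caller keeps the
// output shape dynamic.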
auto input_shape = get_input_partial_shape(0).to_shape();
auto lower_bounds = dynamic_pointer_cast<op::Constant>(get_argument(1));
auto upper_bounds = dynamic_pointer_cast<op::Constant>(get_argument(2));
auto strides = dynamic_pointer_cast<op::Constant>(get_argument(3));
if (lower_bounds && upper_bounds && strides)
{
auto lb = lower_bounds->get_vector<int64_t>();
auto ub = upper_bounds->get_vector<int64_t>();
auto str = strides->get_vector<int64_t>();
int max_dims = input_shape.size() + m_new_axis.size();
if (lb.size() && ub.size())
{
NODE_VALIDATION_CHECK(
this,
lb.size() == ub.size(),
"Lower bounds and Upper bounds needs to have same number of values");
}
if (lb.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
lb.size() == str.size(),
"Lower bounds and strides needs to have same number of values");
}
if (ub.size() && str.size())
{
NODE_VALIDATION_CHECK(this,
ub.size() == str.size(),
"Upper bounds and strides needs to have same number of values");
}
int bounds_size =
lb.size() ? lb.size() : (ub.size() ? ub.size() : (str.size() ? str.size() : 0));
NODE_VALIDATION_CHECK(
this, m_ellipsis_mask.size() <= 1, "Ellipsis mask cannot specify more than one axis");
int ellipsis_pos1 = m_ellipsis_mask.size() ? *m_ellipsis_mask.begin() : max_dims;
int ellipsis_pos2 = max_dims;
bounds_size -= ellipsis_pos1;
if (bounds_size > 0 && (max_dims - bounds_size) > ellipsis_pos1)
{
ellipsis_pos2 = max_dims - bounds_size;
}
std::vector<int> begin_dms(max_dims, 0);
std::vector<int> end_dms(max_dims, -1);
std::vector<int> stride_dms(max_dims, 1);
int i, j, k, bj, ej, sj;
Shape out_dims;
for (i = 0, j = 0, k = 0, bj = 0, ej = 0, sj = 0; i < max_dims; i++)
{
if (i >= ellipsis_pos1 && i < ellipsis_pos2)
{
if (m_new_axis.find(i) == m_new_axis.end())
{
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j++] + end_dms[i];
}
else
{
end_dms[i] = begin_dms[i];
}
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
k = ellipsis_pos1;
continue;
}
stride_dms[i] = (str.size() > sj && str[sj] != 0) ? str[sj++] : 1;
// Use lower_bounds if mask is not set
if (m_lower_bounds_mask.find(j) == m_lower_bounds_mask.end())
{
begin_dms[i] = lb.size() > bj ? lb[bj] : (stride_dms[i] > 0 ? 0 : -1);
}
else
{
begin_dms[i] = stride_dms[i] > 0 ? 0 : -1;
}
bj++;
begin_dms[i] = begin_dms[i] >= 0 ? begin_dms[i] : input_shape[j] + begin_dms[i];
// Clipping 'begin'
begin_dms[i] =
(begin_dms[i] < 0) ? 0 : (begin_dms[i] >= input_shape[j] ? input_shape[j] - 1
: begin_dms[i]);
// Use upper_bounds if mask is not set
if (m_upper_bounds_mask.find(j) == m_upper_bounds_mask.end())
{
int end_dms_tmp =
ub.size() > ej ? (stride_dms[i] > 0 ? ub[ej] - 1 : ub[ej] + 1) : end_dms[i];
end_dms[i] = ub.size() > ej ? end_dms_tmp : (stride_dms[i] > 0 ? -1 : 0);
}
else
{
end_dms[i] = stride_dms[i] > 0 ? -1 : 0;
}
ej++;
end_dms[i] = end_dms[i] >= 0 ? end_dms[i] : input_shape[j] + end_dms[i];
// Clipping 'end'
end_dms[i] = (end_dms[i] < 0) ? 0 : (end_dms[i] >= input_shape[j] ? input_shape[j] - 1
: end_dms[i]);
if (m_new_axis.find(i) == m_new_axis.end())
{
j++;
}
else
{
end_dms[i] = 0;
}
if (m_shrink_axis.find(k) != m_shrink_axis.end())
{
end_dms[i] = begin_dms[i];
}
else
{
out_dims.push_back(
static_cast<int>(ceil(static_cast<float>(abs(end_dms[i] - begin_dms[i]) + 1) /
static_cast<float>(abs(stride_dms[i])))));
}
k++;
}
return out_dims;
}
return Shape{};
}
void op::DynSlice::validate_and_infer_types()
{
auto lower_bounds_et = get_input_element_type(1);
......@@ -66,20 +215,26 @@ void op::DynSlice::validate_and_infer_types()
strides_shape.rank(),
".");
NODE_VALIDATION_CHECK(this,
lower_bounds_shape.compatible(PartialShape{arg_shape.rank()}),
"Lower bounds must have shape [n], where n is the rank of arg.");
NODE_VALIDATION_CHECK(this,
upper_bounds_shape.compatible(PartialShape{arg_shape.rank()}),
"Upper bounds must have shape [n], where n is the rank of arg.");
NODE_VALIDATION_CHECK(this,
strides_shape.compatible(PartialShape{arg_shape.rank()}),
"Strides shape must have shape [n], where n is the rank of arg.");
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
if (get_input_partial_shape(0).is_static())
{
auto shape = compute_output_shape();
if (shape != Shape{})
{
set_output_type(0, get_input_element_type(0), shape);
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
}
else
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(arg_shape.rank()));
}
}
shared_ptr<Node> op::DynSlice::copy_with_new_args(const NodeVector& new_args) const
......
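To make the per-axis arithmetic in compute_output_shape concrete, the static-shape test case below (arg {2,3,4,5,6}, lower bounds {0,1,2,3,1}, upper bounds {1,3,3,5,6}, strides {1,1,1,2,2}, no masks) works out as follows, using end = ub - 1 for positive strides and output dim = ceil((|end - begin| + 1) / |stride|) as in the code above:

// axis 0: ceil((0-0+1)/1) = 1    axis 1: ceil((2-1+1)/1) = 2
// axis 2: ceil((2-2+1)/1) = 1    axis 3: ceil((4-3+1)/2) = 1
// axis 4: ceil((5-1+1)/2) = 3    -> Shape{1, 2, 1, 1, 3}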
......@@ -34,10 +34,20 @@ namespace ngraph
/// \param upper_bounds The axiswise upper bounds of the slice (exclusive).
/// \param strides The slicing strides; for example, strides of `{n,m}` means to take
/// every nth row and every mth column of the input matrix.
/// \param lower_bounds_mask Ignores lower_bounds for the axes where the mask is set
/// \param upper_bounds_mask Ignores upper_bounds for the axes where the mask is set
/// \param new_axis Inserts a new axis of size 1 at each set position
/// \param shrink_axis Removes the dimension at each set position
/// \param ellipsis_mask Inserts the missing dimensions at the set position (at most one axis)
DynSlice(const std::shared_ptr<Node>& arg,
const std::shared_ptr<Node>& lower_bounds,
const std::shared_ptr<Node>& upper_bounds,
const std::shared_ptr<Node>& strides);
const std::shared_ptr<Node>& strides,
const AxisSet& lower_bounds_mask = AxisSet{},
const AxisSet& upper_bounds_mask = AxisSet{},
const AxisSet& new_axis = AxisSet{},
const AxisSet& shrink_axis = AxisSet{},
const AxisSet& ellipsis_mask = AxisSet{});
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
......@@ -46,6 +56,16 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints,
const NodeVector& deltas) override;
void validate_and_infer_types() override;
private:
/// Helper method to compute output shape
Shape compute_output_shape() const;
AxisSet m_lower_bounds_mask;
AxisSet m_upper_bounds_mask;
AxisSet m_new_axis;
AxisSet m_shrink_axis;
AxisSet m_ellipsis_mask;
};
}
}
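A hedged usage sketch of the extended constructor; the values mirror the static-shape test case below, while the helper name and the shrink_axis{0} mask are purely illustrative (that mask drops the first sliced axis from the statically inferred output shape):

#include "ngraph/op/constant.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/parameter.hpp"
using namespace ngraph;

std::shared_ptr<op::DynSlice> make_example_dyn_slice()
{
    auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
    auto lb = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
    auto ub = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
    auto st = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
    // All masks default to empty AxisSets; here shrink_axis{0} removes the
    // first sliced axis from the output.
    return std::make_shared<op::DynSlice>(
        arg, lb, ub, st, AxisSet{}, AxisSet{}, AxisSet{}, AxisSet{0}, AxisSet{});
}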
......@@ -79,7 +79,7 @@ void op::PriorBox::validate_and_infer_types()
"Layer shape must have rank 2",
const_shape->get_shape());
auto layer_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
auto layer_shape = const_shape->get_shape_val();
size_t num_priors = 0;
// {Prior boxes, Variance-adjusted prior boxes}
if (m_scale_all)
......
......@@ -91,7 +91,7 @@ void op::PriorBoxClustered::validate_and_infer_types()
"Layer shape must have rank 2",
const_shape->get_shape());
auto layer_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
auto layer_shape = const_shape->get_shape_val();
// {Prior boxes, variances-adjusted prior boxes}
set_output_type(
0, element::f32, Shape{2, 4 * layer_shape[0] * layer_shape[1] * m_num_priors});
......
......@@ -75,7 +75,7 @@ void op::Proposal::validate_and_infer_types()
"Layer shape must have rank 2",
const_shape->get_shape());
auto image_shape = static_cast<const int64_t*>(const_shape->get_data_ptr());
auto image_shape = const_shape->get_shape_val();
set_output_type(0, element::f32, Shape{image_shape[0] * m_post_nms_topn, 5});
}
......
......@@ -19,6 +19,7 @@
#include <tuple>
#include "ngraph/coordinate_diff.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/attr_types.hpp"
......
......@@ -13100,6 +13100,106 @@ TEST(type_prop, dynslice_arg_rank_static_dynamic_params_rank_dynamic_ok)
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic(4)));
}
TEST(type_prop, dynslice_static_shape)
{
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 3, 4, 5, 6});
auto lower_bounds = op::Constant::create(element::i64, Shape{5}, {0, 1, 2, 3, 1});
auto upper_bounds = op::Constant::create(element::i64, Shape{5}, {1, 3, 3, 5, 6});
auto strides = op::Constant::create(element::i64, Shape{5}, {1, 1, 1, 2, 2});
auto r = make_shared<op::DynSlice>(arg, lower_bounds, upper_bounds, strides);
EXPECT_EQ(r->get_output_element_type(0), element::f32);
EXPECT_EQ(r->get_shape(), (Shape{1, 2, 1, 1, 3}));
}
struct DynSliceParams
{
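// shapes: {arg shape, lower_bounds shape, upper_bounds shape, strides shape,
//          expected output shape}
// vals:   {lower_bounds values, upper_bounds values, strides values}
// attrs:  {lower_bounds_mask, upper_bounds_mask, new_axis, shrink_axis,
//          ellipsis_mask}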
std::vector<Shape> shapes;
std::vector<std::vector<int64_t>> vals;
std::vector<AxisSet> attrs;
DynSliceParams(const std::vector<Shape>& shape,
const std::vector<std::vector<int64_t>>& val,
const std::vector<AxisSet>& attr)
: shapes(shape)
, vals(val)
, attrs(attr)
{
}
};
struct DeduceDynSliceTest : ::testing::TestWithParam<DynSliceParams>
{
};
TEST_P(DeduceDynSliceTest, output_shape)
{
auto tp = GetParam();
auto arg = make_shared<op::Parameter>(element::f32, tp.shapes[0]);
auto lower_bounds = op::Constant::create(element::i64, tp.shapes[1], tp.vals[0]);
auto upper_bounds = op::Constant::create(element::i64, tp.shapes[2], tp.vals[1]);
auto strides = op::Constant::create(element::i64, tp.shapes[3], tp.vals[2]);
auto r = make_shared<op::DynSlice>(arg,
lower_bounds,
upper_bounds,
strides,
tp.attrs[0],
tp.attrs[1],
tp.attrs[2],
tp.attrs[3],
tp.attrs[4]);
EXPECT_EQ(r->get_shape(), tp.shapes[4]);
}
INSTANTIATE_TEST_CASE_P(
type_prop,
DeduceDynSliceTest,
::testing::Values(
DynSliceParams({{2, 3, 4, 5, 6}, {5}, {5}, {5}, {1, 2, 1, 1, 3}},
{{0, 1, 2, 3, 1}, {1, 3, 3, 5, 6}, {1, 1, 1, 2, 2}},
{{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {0}, {0}, {0}, {10}}, {{}, {}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {10}},
{{0}, {0}, {}},
{{}, {0}, {}, {}, {}}), // end-mask
DynSliceParams({{10}, {1}, {1}, {0}, {9}},
{{-1}, {-1}, {}},
{{0}, {}, {}, {}, {}}), // begin-mask
DynSliceParams({{10}, {1}, {1}, {0}, {10}}, {{0}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{5}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {0}, {5}}, {{-5}, {10}, {}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {6}},
{{-5}, {0}, {-1}}, // negative-stride
{{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {3}}, {{-5}, {2}, {-1}}, {{}, {}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{0}, {0}, {2}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{1}, {0}, {2}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {10}}, {{-1}, {0}, {-1}}, {{}, {0}, {}, {}, {}}),
DynSliceParams({{10}, {1}, {1}, {1}, {5}}, {{-1}, {0}, {-2}}, {{}, {0}, {}, {}, {}}),
/* Axis Masks: New, Shrink, Ellipsis */
DynSliceParams({{10}, {1}, {1}, {0}, {1, 10}}, {{0}, {10}, {}}, {{}, {}, {0}, {}, {}}),
DynSliceParams({{1, 2, 3}, {2}, {2}, {0}, {1, 2, 2}},
{{0, 0}, {1, 2}, {}},
{{}, {}, {}, {}, {1}}),
DynSliceParams({{1, 2, 3}, {4}, {4}, {0}, {1, 2, 1}},
{{0, 0, 0, 1}, {2, 3, 2, 2}, {}},
{{}, {}, {2}, {3}, {}}),
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2, 1}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {}, {1}}),
DynSliceParams({{1, 2, 2, 2}, {1}, {1}, {1}, {1, 2, 2}},
{{-1}, {0}, {-2}},
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 2, 2}, {4}, {4}, {0}, {1, 2, 2}},
{{0, 1, 0, 0}, {1, 2, 2, 2}, {}},
{{1}, {1}, {}, {1}, {}}),
DynSliceParams({{1, 2, 3}, {3}, {3}, {0}, {1, 1, 2}},
{{0, 0, 1}, {2, 2, 2}, {}},
{{}, {}, {0}, {2}, {1}})));
void DynSlice_Test_Shape_Except(const shared_ptr<Node>& param_0,
const shared_ptr<Node>& param_1,
const shared_ptr<Node>& param_2,
......