Commit 080d4f95 authored by Adam Procter's avatar Adam Procter

Implement DynElimination for DynSlice; simple test passing, but more needed

parent f561c937
......@@ -15,8 +15,11 @@
//*****************************************************************************
#include "dyn_elimination.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
......@@ -27,6 +30,7 @@ pass::DynElimination::DynElimination()
: GraphRewrite()
{
construct_transpose();
construct_dyn_reshape();
}
void pass::DynElimination::construct_transpose()
......@@ -74,3 +78,302 @@ void pass::DynElimination::construct_transpose()
auto transpose_matcher = make_shared<pattern::Matcher>(transpose, "DynElimination.Transpose");
add_matcher(transpose_matcher, transpose_callback, all_pass_property_off);
}
//
// We eliminate DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// vconst
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct SlicePlan
{
// Parameters for the Slice
std::vector<int64_t> begins;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
// Shapes coming into, and going out of, the Reshape.
Shape reshape_in_shape;
Shape reshape_out_shape;
// Parameters for the Reverse
std::set<size_t> reverse_axes;
};
static SlicePlan make_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask)
{
NGRAPH_CHECK(begins.size() == ends.size());
NGRAPH_CHECK(ends.size() == strides.size());
size_t num_slice_indices = begins.size();
size_t num_real_axes = 0;
size_t num_shrink_axes = 0;
size_t num_new_axes = 0;
bool ellipsis_found = false;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for (size_t i = 0; i < num_slice_indices; i++)
{
if (ellipsis_mask.count(i))
{
NGRAPH_CHECK(!ellipsis_found);
ellipsis_found = true;
}
else if (new_axis_mask.count(i))
{
num_new_axes++;
}
else
{
if (shrink_axis_mask.count(i))
{
num_shrink_axes++;
}
num_real_axes++;
}
}
NGRAPH_CHECK(num_real_axes <= input_shape.size());
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t ellipsis_size = input_shape.size() - num_real_axes;
// Initialize our slice plan.
SlicePlan p;
p.begins = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.ends = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.strides = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.reshape_in_shape = Shape(num_real_axes + ellipsis_size);
p.reshape_out_shape = Shape(num_new_axes + num_real_axes + ellipsis_size - num_shrink_axes);
p.reverse_axes = AxisSet{};
// Begin a maddeningly delicate loop to desugar the original slice specs.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t i_in = 0;
size_t i_out = 0;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto expand_ellipsis = [&]() {
for (size_t i = 0; i < ellipsis_size; i++)
{
p.begins[i_in] = 0;
p.ends[i_in] = int64_t(input_shape[i_in]);
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = input_shape[i_in];
p.reshape_out_shape[i_out] = input_shape[i_in];
i_in++;
i_out++;
}
};
for (size_t i = 0; i < num_slice_indices; i++)
{
// If this is a "newaxis", we throw a 1 into the final shape, but it
// will not be present in the intermediate shape and does not
// correspond to anything in the original shape.
if (new_axis_mask.count(i))
{
p.reshape_out_shape[i_out] = 1;
i_out++;
}
// If this is a "shrunken" axis, the intermediate shape will have a
// "1" here, but nothing will be there in the final shape.
else if (shrink_axis_mask.count(i))
{
int64_t begin = begins[i];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK(begin >= -(int64_t(input_shape[i_in])) &&
begin < int64_t(input_shape[i_in]));
if (begin < 0)
{
begin += int64_t(input_shape[i_in]);
}
p.begins[i_in] = begin;
p.ends[i_in] = begin + 1;
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = 1;
i_in++;
}
// If this is the ellipsis, expand it (see expand_ellipsis above for
// details).
else if (ellipsis_mask.count(i))
{
expand_ellipsis();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool is_reverse = strides[i] < 0;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t real_begin = begins[i];
if (lower_bounds_mask.count(i))
{
real_begin = (is_reverse ? int64_t(input_shape[i_in] - 1) : 0);
}
else if (real_begin < 0)
{
real_begin += int64_t(input_shape[i_in]);
}
int64_t max_real_begin = int64_t(input_shape[i_in]) - (is_reverse ? 1 : 0);
real_begin = std::max(int64_t(0), std::min(max_real_begin, real_begin));
// Adjust the ending for from-the-right indexing, and clip.
int64_t real_end = ends[i];
if (upper_bounds_mask.count(i))
{
real_end = (is_reverse ? -1 : int64_t(input_shape[i_in]));
}
else if (real_end < 0)
{
real_end += int64_t(input_shape[i_in]);
}
int64_t min_real_end = (is_reverse ? -1 : 0);
real_end = std::max(min_real_end, std::min(int64_t(input_shape[i_in]), real_end));
// Adjust the stride for backwards slicing.
int64_t real_stride = std::abs(strides[i]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if (is_reverse)
{
real_end += std::max(int64_t(0), real_begin - real_end - 1) % real_stride;
std::swap(real_begin, real_end);
real_begin++;
real_end++;
p.reverse_axes.insert(i_out);
}
// Compute output dimension.
size_t dim = (real_end <= real_begin
? 0
: size_t(real_end - real_begin - 1) / size_t(real_stride) + 1);
p.reshape_in_shape[i_in] = dim;
p.reshape_out_shape[i_out] = dim;
// Set up the begin/end/stride.
p.begins[i_in] = real_begin;
p.ends[i_in] = real_end;
p.strides[i_in] = real_stride;
i_in++;
i_out++;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if (!ellipsis_found)
{
expand_ellipsis();
}
return p;
}
void pass::DynElimination::construct_dyn_reshape()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto ends_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto strides_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_slice_pat = make_shared<op::DynSlice>(data_arg_label,
begins_arg_label,
ends_arg_label,
strides_arg_label,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto dyn_slice_callback = [data_arg_label, begins_arg_label, ends_arg_label, strides_arg_label](
pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto begins_arg = static_pointer_cast<op::Constant>(pattern_map[begins_arg_label]);
auto ends_arg = static_pointer_cast<op::Constant>(pattern_map[ends_arg_label]);
auto strides_arg = static_pointer_cast<op::Constant>(pattern_map[strides_arg_label]);
auto dyn_slice = static_pointer_cast<op::DynSlice>(m.get_match_root());
if (data_arg->get_output_partial_shape(0).is_dynamic() ||
begins_arg->get_element_type() != element::i64 ||
ends_arg->get_element_type() != element::i64 ||
strides_arg->get_element_type() != element::i64)
{
return false;
}
SlicePlan p = make_plan(data_arg->get_output_shape(0),
begins_arg->get_vector<int64_t>(),
ends_arg->get_vector<int64_t>(),
strides_arg->get_vector<int64_t>(),
dyn_slice->get_lower_bounds_mask(),
dyn_slice->get_upper_bounds_mask(),
dyn_slice->get_new_axis(),
dyn_slice->get_shrink_axis(),
dyn_slice->get_ellipsis_mask());
shared_ptr<Node> replacement =
make_shared<op::Slice>(data_arg,
Coordinate(p.begins.begin(), p.begins.end()),
Coordinate(p.ends.begin(), p.ends.end()),
Strides(p.strides.begin(), p.strides.end()));
if (p.reshape_in_shape != p.reshape_out_shape)
{
replacement = make_shared<op::Reshape>(
replacement, ngraph::get_default_order(p.reshape_in_shape), p.reshape_out_shape);
}
if (!p.reverse_axes.empty())
{
replacement = make_shared<op::Reverse>(replacement, p.reverse_axes);
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
......@@ -30,6 +30,7 @@ namespace ngraph
private:
void construct_transpose();
void construct_dyn_reshape();
};
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment