Commit 55209b7a authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF updates: Slice, DynSlice (#3340)

* Move SlicePlan out of DynElimination, for reuse in ConstantFolding

* Add CF support for Slice

* Add CF for DynSlice

* Add <algorithm> to slice_plan.cpp, to make Windows happy
parent 5366be98
...@@ -478,6 +478,8 @@ set (SRC ...@@ -478,6 +478,8 @@ set (SRC
shape.hpp shape.hpp
shape_util.cpp shape_util.cpp
shape_util.hpp shape_util.hpp
slice_plan.cpp
slice_plan.hpp
specialize_function.cpp specialize_function.cpp
specialize_function.hpp specialize_function.hpp
state/rng_state.cpp state/rng_state.cpp
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp" #include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp" #include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/shape_of.hpp" #include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/experimental/transpose.hpp" #include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/floor.hpp" #include "ngraph/op/floor.hpp"
...@@ -52,6 +53,7 @@ ...@@ -52,6 +53,7 @@
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp" #include "ngraph/op/reverse.hpp"
#include "ngraph/op/sign.hpp" #include "ngraph/op/sign.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/sqrt.hpp" #include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp" #include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp" #include "ngraph/op/sum.hpp"
...@@ -86,9 +88,11 @@ ...@@ -86,9 +88,11 @@
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/reverse.hpp" #include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/sign.hpp" #include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/runtime/reference/sqrt.hpp" #include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/subtract.hpp" #include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/sum.hpp" #include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
using namespace std; using namespace std;
...@@ -1858,3 +1862,246 @@ void pass::ConstantFolding::construct_constant_concat() ...@@ -1858,3 +1862,246 @@ void pass::ConstantFolding::construct_constant_concat()
make_shared<pattern::Matcher>(concat_op, "ConstantFolding.ConstantConcat"); make_shared<pattern::Matcher>(concat_op, "ConstantFolding.ConstantConcat");
this->add_matcher(concat_matcher, constant_concat_callback, all_pass_property_off); this->add_matcher(concat_matcher, constant_concat_callback, all_pass_property_off);
} }
template <class T>
shared_ptr<op::Constant> fold_constant_slice(shared_ptr<op::Constant> constant,
shared_ptr<op::Slice> slice)
{
auto out_shape = slice->get_shape();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::slice<T>(constant->get_data_ptr<T>(),
out_vec.data(),
constant->get_shape(),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
out_shape);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_slice()
{
auto data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto slice_op = make_shared<op::Slice>(
data_label, Coordinate{1, 1, 1}, Coordinate{2, 3, 4}, Strides{1, 1, 2});
auto constant_slice_callback = [data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_slice_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto data_node = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto slice = static_pointer_cast<op::Slice>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (slice->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_slice");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_slice");
break;
case element::Type_t::boolean:
replacement = fold_constant_slice<char>(data_node, slice);
break;
case element::Type_t::bf16:
replacement = fold_constant_slice<bfloat16>(data_node, slice);
break;
case element::Type_t::f16:
replacement = fold_constant_slice<float16>(data_node, slice);
break;
case element::Type_t::f32:
replacement = fold_constant_slice<float>(data_node, slice);
break;
case element::Type_t::f64:
replacement = fold_constant_slice<double>(data_node, slice);
break;
case element::Type_t::i8:
replacement = fold_constant_slice<int8_t>(data_node, slice);
break;
case element::Type_t::i16:
replacement = fold_constant_slice<int16_t>(data_node, slice);
break;
case element::Type_t::i32:
replacement = fold_constant_slice<int32_t>(data_node, slice);
break;
case element::Type_t::i64:
replacement = fold_constant_slice<int64_t>(data_node, slice);
break;
case element::Type_t::u8:
replacement = fold_constant_slice<uint8_t>(data_node, slice);
break;
case element::Type_t::u16:
replacement = fold_constant_slice<uint16_t>(data_node, slice);
break;
case element::Type_t::u32:
replacement = fold_constant_slice<uint32_t>(data_node, slice);
break;
case element::Type_t::u64:
replacement = fold_constant_slice<uint64_t>(data_node, slice);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto slice_matcher = make_shared<pattern::Matcher>(slice_op, "ConstantFolding.ConstantSlice");
this->add_matcher(slice_matcher, constant_slice_callback, all_pass_property_off);
}
template <class T>
shared_ptr<op::Constant> fold_constant_dyn_slice(shared_ptr<op::Constant> data,
shared_ptr<op::Constant> lb,
shared_ptr<op::Constant> ub,
shared_ptr<op::Constant> strides,
shared_ptr<op::DynSlice> slice)
{
SlicePlan plan = make_slice_plan(data->get_shape(),
lb->get_vector<int64_t>(),
ub->get_vector<int64_t>(),
strides->get_vector<int64_t>(),
slice->get_lower_bounds_mask(),
slice->get_upper_bounds_mask(),
slice->get_new_axis(),
slice->get_shrink_axis(),
slice->get_ellipsis_mask());
vector<T> slice_out_vec(shape_size(plan.reshape_in_shape));
runtime::reference::slice<T>(data->get_data_ptr<T>(),
slice_out_vec.data(),
data->get_shape(),
Coordinate(plan.begins.begin(), plan.begins.end()),
Coordinate(plan.ends.begin(), plan.ends.end()),
Strides(plan.strides.begin(), plan.strides.end()),
plan.reshape_in_shape);
vector<T> reshape_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reshape<T>(slice_out_vec.data(),
reshape_out_vec.data(),
plan.reshape_in_shape,
get_default_order(plan.reshape_in_shape.size()),
plan.reshape_out_shape);
vector<T> reverse_out_vec(shape_size(plan.reshape_out_shape));
runtime::reference::reverse<T>(reshape_out_vec.data(),
reverse_out_vec.data(),
plan.reshape_out_shape,
plan.reshape_out_shape,
plan.reverse_axes);
return make_shared<op::Constant>(
data->get_element_type(), plan.reshape_out_shape, reverse_out_vec);
}
void pass::ConstantFolding::construct_constant_dyn_slice()
{
auto data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 3, 4}, pattern::has_class<op::Constant>());
auto lb_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto ub_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto strides_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_slice_op = make_shared<op::DynSlice>(data_label,
lb_label,
ub_label,
strides_label,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto constant_dyn_slice_callback = [data_label, lb_label, ub_label, strides_label](
pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_dyn_slice_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto data_node = static_pointer_cast<op::Constant>(pattern_map[data_label]);
auto lb_node = static_pointer_cast<op::Constant>(pattern_map[lb_label]);
auto ub_node = static_pointer_cast<op::Constant>(pattern_map[ub_label]);
auto strides_node = static_pointer_cast<op::Constant>(pattern_map[strides_label]);
auto dyn_slice = static_pointer_cast<op::DynSlice>(m.get_match_root());
std::shared_ptr<op::Constant> replacement;
switch (dyn_slice->get_output_element_type(0).get_type_enum())
{
case element::Type_t::undefined:
NGRAPH_CHECK(false, "Encountered 'undefined' element type in fold_constant_dyn_slice");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in fold_constant_dyn_slice");
break;
case element::Type_t::boolean:
replacement =
fold_constant_dyn_slice<char>(data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::bf16:
replacement = fold_constant_dyn_slice<bfloat16>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::f16:
replacement = fold_constant_dyn_slice<float16>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::f32:
replacement = fold_constant_dyn_slice<float>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::f64:
replacement = fold_constant_dyn_slice<double>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::i8:
replacement = fold_constant_dyn_slice<int8_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::i16:
replacement = fold_constant_dyn_slice<int16_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::i32:
replacement = fold_constant_dyn_slice<int32_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::i64:
replacement = fold_constant_dyn_slice<int64_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::u8:
replacement = fold_constant_dyn_slice<uint8_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::u16:
replacement = fold_constant_dyn_slice<uint16_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::u32:
replacement = fold_constant_dyn_slice<uint32_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
case element::Type_t::u64:
replacement = fold_constant_dyn_slice<uint64_t>(
data_node, lb_node, ub_node, strides_node, dyn_slice);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_op, "ConstantFolding.ConstantDynSlice");
this->add_matcher(dyn_slice_matcher, constant_dyn_slice_callback, all_pass_property_off);
}
...@@ -45,6 +45,8 @@ public: ...@@ -45,6 +45,8 @@ public:
PRODUCT, PRODUCT,
SUM, SUM,
CONCAT, CONCAT,
SLICE,
DYN_SLICE,
DYN_RESHAPE, DYN_RESHAPE,
TRANSPOSE TRANSPOSE
}; };
...@@ -66,6 +68,8 @@ public: ...@@ -66,6 +68,8 @@ public:
construct_constant_product(); construct_constant_product();
construct_constant_sum(); construct_constant_sum();
construct_constant_concat(); construct_constant_concat();
construct_constant_slice();
construct_constant_dyn_slice();
construct_constant_dyn_reshape(); construct_constant_dyn_reshape();
construct_constant_transpose(); construct_constant_transpose();
} }
...@@ -94,6 +98,8 @@ public: ...@@ -94,6 +98,8 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break; case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break; case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break; case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break; case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break; case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
} }
...@@ -114,6 +120,8 @@ private: ...@@ -114,6 +120,8 @@ private:
void construct_constant_product(); void construct_constant_product();
void construct_constant_sum(); void construct_constant_sum();
void construct_constant_concat(); void construct_constant_concat();
void construct_constant_slice();
void construct_constant_dyn_slice();
void construct_constant_dyn_reshape(); void construct_constant_dyn_reshape();
void construct_constant_transpose(); void construct_constant_transpose();
......
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include "ngraph/op/slice.hpp" #include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/slice_plan.hpp"
using namespace std; using namespace std;
using namespace ngraph; using namespace ngraph;
...@@ -138,242 +139,6 @@ void pass::DynElimination::construct_dyn_broadcast() ...@@ -138,242 +139,6 @@ void pass::DynElimination::construct_dyn_broadcast()
add_matcher(dyn_broadcast_matcher, dyn_broadcast_callback, all_pass_property_off); add_matcher(dyn_broadcast_matcher, dyn_broadcast_callback, all_pass_property_off);
} }
//
// We eliminate DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// v
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct SlicePlan
{
// Parameters for the Slice
std::vector<int64_t> begins;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
// Shapes coming into, and going out of, the Reshape.
Shape reshape_in_shape;
Shape reshape_out_shape;
// Parameters for the Reverse
std::set<size_t> reverse_axes;
};
static SlicePlan make_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask)
{
NGRAPH_CHECK(begins.size() == ends.size());
NGRAPH_CHECK(ends.size() == strides.size());
size_t num_slice_indices = begins.size();
size_t num_real_axes = 0;
size_t num_shrink_axes = 0;
size_t num_new_axes = 0;
bool ellipsis_found = false;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for (size_t i = 0; i < num_slice_indices; i++)
{
if (ellipsis_mask.count(i))
{
NGRAPH_CHECK(!ellipsis_found);
ellipsis_found = true;
}
else if (new_axis_mask.count(i))
{
num_new_axes++;
}
else
{
if (shrink_axis_mask.count(i))
{
num_shrink_axes++;
}
num_real_axes++;
}
}
NGRAPH_CHECK(num_real_axes <= input_shape.size(),
"num_real_axes=",
num_real_axes,
", input_shape=",
input_shape);
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t ellipsis_size = input_shape.size() - num_real_axes;
// Initialize our slice plan.
SlicePlan p;
p.begins = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.ends = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.strides = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.reshape_in_shape = Shape(num_real_axes + ellipsis_size);
p.reshape_out_shape = Shape(num_new_axes + num_real_axes + ellipsis_size - num_shrink_axes);
p.reverse_axes = AxisSet{};
// Begin a maddeningly delicate loop to desugar the original slice.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t i_in = 0;
size_t i_out = 0;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto expand_ellipsis = [&]() {
for (size_t i = 0; i < ellipsis_size; i++)
{
p.begins[i_in] = 0;
p.ends[i_in] = int64_t(input_shape[i_in]);
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = input_shape[i_in];
p.reshape_out_shape[i_out] = input_shape[i_in];
i_in++;
i_out++;
}
};
for (size_t i = 0; i < num_slice_indices; i++)
{
// If this is a "newaxis", then reshape_out_shape will have a 1 here,
// but reshape_in_shape will not.
if (new_axis_mask.count(i))
{
p.reshape_out_shape[i_out] = 1;
i_out++;
}
// If this is a "shrunken" axis, then reshape_in_shape will have a 1
// here, but reshape_out_shape will not.
else if (shrink_axis_mask.count(i))
{
int64_t begin = begins[i];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK(begin >= -(int64_t(input_shape[i_in])) &&
begin < int64_t(input_shape[i_in]));
if (begin < 0)
{
begin += int64_t(input_shape[i_in]);
}
p.begins[i_in] = begin;
p.ends[i_in] = begin + 1;
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = 1;
i_in++;
}
// If this is the ellipsis, expand it.
else if (ellipsis_mask.count(i))
{
expand_ellipsis();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool is_reverse = strides[i] < 0;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t real_begin = begins[i];
if (lower_bounds_mask.count(i))
{
real_begin = (is_reverse ? int64_t(input_shape[i_in] - 1) : 0);
}
else if (real_begin < 0)
{
real_begin += int64_t(input_shape[i_in]);
}
int64_t max_real_begin = int64_t(input_shape[i_in]) - (is_reverse ? 1 : 0);
real_begin = std::max(int64_t(0), std::min(max_real_begin, real_begin));
// Adjust the ending for from-the-right indexing, and clip.
int64_t real_end = ends[i];
if (upper_bounds_mask.count(i))
{
real_end = (is_reverse ? -1 : int64_t(input_shape[i_in]));
}
else if (real_end < 0)
{
real_end += int64_t(input_shape[i_in]);
}
int64_t min_real_end = (is_reverse ? -1 : 0);
real_end = std::max(min_real_end, std::min(int64_t(input_shape[i_in]), real_end));
// Ensure stride is not zero, and adjust it for backwards slicing.
NGRAPH_CHECK(strides[i] != 0);
int64_t real_stride = std::abs(strides[i]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if (is_reverse)
{
real_end += std::max(int64_t(0), real_begin - real_end - 1) % real_stride;
std::swap(real_begin, real_end);
real_begin++;
real_end++;
p.reverse_axes.insert(i_out);
}
// nGraph's slice op does not like it when end < begin, so we truncate for that case
// here.
if (real_end < real_begin)
{
real_end = real_begin;
}
// Compute output dimension.
size_t dim = (real_end <= real_begin
? 0
: size_t(real_end - real_begin - 1) / size_t(real_stride) + 1);
p.reshape_in_shape[i_in] = dim;
p.reshape_out_shape[i_out] = dim;
// Set up the begin/end/stride.
p.begins[i_in] = real_begin;
p.ends[i_in] = real_end;
p.strides[i_in] = real_stride;
i_in++;
i_out++;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if (!ellipsis_found)
{
expand_ellipsis();
}
return p;
}
void pass::DynElimination::construct_dyn_slice() void pass::DynElimination::construct_dyn_slice()
{ {
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3}); auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
...@@ -411,15 +176,15 @@ void pass::DynElimination::construct_dyn_slice() ...@@ -411,15 +176,15 @@ void pass::DynElimination::construct_dyn_slice()
return false; return false;
} }
SlicePlan p = make_plan(data_arg->get_output_shape(0), SlicePlan p = make_slice_plan(data_arg->get_output_shape(0),
begins_arg->get_vector<int64_t>(), begins_arg->get_vector<int64_t>(),
ends_arg->get_vector<int64_t>(), ends_arg->get_vector<int64_t>(),
strides_arg->get_vector<int64_t>(), strides_arg->get_vector<int64_t>(),
dyn_slice->get_lower_bounds_mask(), dyn_slice->get_lower_bounds_mask(),
dyn_slice->get_upper_bounds_mask(), dyn_slice->get_upper_bounds_mask(),
dyn_slice->get_new_axis(), dyn_slice->get_new_axis(),
dyn_slice->get_shrink_axis(), dyn_slice->get_shrink_axis(),
dyn_slice->get_ellipsis_mask()); dyn_slice->get_ellipsis_mask());
shared_ptr<Node> replacement = shared_ptr<Node> replacement =
make_shared<op::Slice>(data_arg, make_shared<op::Slice>(data_arg,
...@@ -491,15 +256,15 @@ void pass::DynElimination::construct_dyn_replace_slice() ...@@ -491,15 +256,15 @@ void pass::DynElimination::construct_dyn_replace_slice()
return false; return false;
} }
SlicePlan p = make_plan(data_arg->get_output_shape(0), SlicePlan p = make_slice_plan(data_arg->get_output_shape(0),
begins_arg->get_vector<int64_t>(), begins_arg->get_vector<int64_t>(),
ends_arg->get_vector<int64_t>(), ends_arg->get_vector<int64_t>(),
strides_arg->get_vector<int64_t>(), strides_arg->get_vector<int64_t>(),
dyn_replace_slice->get_lower_bounds_mask(), dyn_replace_slice->get_lower_bounds_mask(),
dyn_replace_slice->get_upper_bounds_mask(), dyn_replace_slice->get_upper_bounds_mask(),
dyn_replace_slice->get_new_axis(), dyn_replace_slice->get_new_axis(),
dyn_replace_slice->get_shrink_axis(), dyn_replace_slice->get_shrink_axis(),
dyn_replace_slice->get_ellipsis_mask()); dyn_replace_slice->get_ellipsis_mask());
shared_ptr<Node> substitute_replacement_arg = replacement_arg; shared_ptr<Node> substitute_replacement_arg = replacement_arg;
......
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "ngraph/check.hpp"
#include "ngraph/slice_plan.hpp"
using namespace ngraph;
SlicePlan ngraph::make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask)
{
NGRAPH_CHECK(begins.size() == ends.size());
NGRAPH_CHECK(ends.size() == strides.size());
size_t num_slice_indices = begins.size();
size_t num_real_axes = 0;
size_t num_shrink_axes = 0;
size_t num_new_axes = 0;
bool ellipsis_found = false;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for (size_t i = 0; i < num_slice_indices; i++)
{
if (ellipsis_mask.count(i))
{
NGRAPH_CHECK(!ellipsis_found);
ellipsis_found = true;
}
else if (new_axis_mask.count(i))
{
num_new_axes++;
}
else
{
if (shrink_axis_mask.count(i))
{
num_shrink_axes++;
}
num_real_axes++;
}
}
NGRAPH_CHECK(num_real_axes <= input_shape.size(),
"num_real_axes=",
num_real_axes,
", input_shape=",
input_shape);
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t ellipsis_size = input_shape.size() - num_real_axes;
// Initialize our slice plan.
SlicePlan p;
p.begins = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.ends = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.strides = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.reshape_in_shape = Shape(num_real_axes + ellipsis_size);
p.reshape_out_shape = Shape(num_new_axes + num_real_axes + ellipsis_size - num_shrink_axes);
p.reverse_axes = AxisSet{};
// Begin a maddeningly delicate loop to desugar the original slice.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t i_in = 0;
size_t i_out = 0;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto expand_ellipsis = [&]() {
for (size_t i = 0; i < ellipsis_size; i++)
{
p.begins[i_in] = 0;
p.ends[i_in] = int64_t(input_shape[i_in]);
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = input_shape[i_in];
p.reshape_out_shape[i_out] = input_shape[i_in];
i_in++;
i_out++;
}
};
for (size_t i = 0; i < num_slice_indices; i++)
{
// If this is a "newaxis", then reshape_out_shape will have a 1 here,
// but reshape_in_shape will not.
if (new_axis_mask.count(i))
{
p.reshape_out_shape[i_out] = 1;
i_out++;
}
// If this is a "shrunken" axis, then reshape_in_shape will have a 1
// here, but reshape_out_shape will not.
else if (shrink_axis_mask.count(i))
{
int64_t begin = begins[i];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK(begin >= -(int64_t(input_shape[i_in])) &&
begin < int64_t(input_shape[i_in]));
if (begin < 0)
{
begin += int64_t(input_shape[i_in]);
}
p.begins[i_in] = begin;
p.ends[i_in] = begin + 1;
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = 1;
i_in++;
}
// If this is the ellipsis, expand it.
else if (ellipsis_mask.count(i))
{
expand_ellipsis();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool is_reverse = strides[i] < 0;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t real_begin = begins[i];
if (lower_bounds_mask.count(i))
{
real_begin = (is_reverse ? int64_t(input_shape[i_in] - 1) : 0);
}
else if (real_begin < 0)
{
real_begin += int64_t(input_shape[i_in]);
}
int64_t max_real_begin = int64_t(input_shape[i_in]) - (is_reverse ? 1 : 0);
real_begin = std::max(int64_t(0), std::min(max_real_begin, real_begin));
// Adjust the ending for from-the-right indexing, and clip.
int64_t real_end = ends[i];
if (upper_bounds_mask.count(i))
{
real_end = (is_reverse ? -1 : int64_t(input_shape[i_in]));
}
else if (real_end < 0)
{
real_end += int64_t(input_shape[i_in]);
}
int64_t min_real_end = (is_reverse ? -1 : 0);
real_end = std::max(min_real_end, std::min(int64_t(input_shape[i_in]), real_end));
// Ensure stride is not zero, and adjust it for backwards slicing.
NGRAPH_CHECK(strides[i] != 0);
int64_t real_stride = std::abs(strides[i]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if (is_reverse)
{
real_end += std::max(int64_t(0), real_begin - real_end - 1) % real_stride;
std::swap(real_begin, real_end);
real_begin++;
real_end++;
p.reverse_axes.insert(i_out);
}
// nGraph's slice op does not like it when end < begin, so we truncate for that case
// here.
if (real_end < real_begin)
{
real_end = real_begin;
}
// Compute output dimension.
size_t dim = (real_end <= real_begin
? 0
: size_t(real_end - real_begin - 1) / size_t(real_stride) + 1);
p.reshape_in_shape[i_in] = dim;
p.reshape_out_shape[i_out] = dim;
// Set up the begin/end/stride.
p.begins[i_in] = real_begin;
p.ends[i_in] = real_end;
p.strides[i_in] = real_stride;
i_in++;
i_out++;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if (!ellipsis_found)
{
expand_ellipsis();
}
return p;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <set>
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
//
// In various places, like ConstantFolding and DynElimination, it is
// useful to transform DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// v
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be
// identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct SlicePlan
{
// Parameters for the Slice
std::vector<int64_t> begins;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
// Shapes coming into, and going out of, the Reshape.
Shape reshape_in_shape;
Shape reshape_out_shape;
// Parameters for the Reverse
AxisSet reverse_axes;
};
SlicePlan make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask);
}
...@@ -739,6 +739,72 @@ TEST(constant_folding, const_floor) ...@@ -739,6 +739,72 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS)); ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
} }
TEST(constant_folding, const_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
auto f = make_shared<Function>(slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, const_dyn_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant_data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_lb{2};
auto constant_lb = make_shared<op::Constant>(element::i64, Shape{1}, values_lb);
vector<int> values_ub{15};
auto constant_ub = make_shared<op::Constant>(element::i64, Shape{1}, values_ub);
vector<int> values_strides{3};
auto constant_strides = make_shared<op::Constant>(element::i64, Shape{1}, values_strides);
auto dyn_slice = make_shared<op::DynSlice>(constant_data,
constant_lb,
constant_ub,
constant_strides,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto f = make_shared<Function>(dyn_slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynSlice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, constant_dyn_reshape) TEST(constant_folding, constant_dyn_reshape)
{ {
Shape shape_in{2, 4}; Shape shape_in{2, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment