Commit 55209b7a authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

CF updates: Slice, DynSlice (#3340)

* Move SlicePlan out of DynElimination, for reuse in ConstantFolding

* Add CF support for Slice

* Add CF for DynSlice

* Add <algorithm> to slice_plan.cpp, to make Windows happy
parent 5366be98
......@@ -478,6 +478,8 @@ set (SRC
shape.hpp
shape_util.cpp
shape_util.hpp
slice_plan.cpp
slice_plan.hpp
specialize_function.cpp
specialize_function.hpp
state/rng_state.cpp
......
This diff is collapsed.
......@@ -45,6 +45,8 @@ public:
PRODUCT,
SUM,
CONCAT,
SLICE,
DYN_SLICE,
DYN_RESHAPE,
TRANSPOSE
};
......@@ -66,6 +68,8 @@ public:
construct_constant_product();
construct_constant_sum();
construct_constant_concat();
construct_constant_slice();
construct_constant_dyn_slice();
construct_constant_dyn_reshape();
construct_constant_transpose();
}
......@@ -94,6 +98,8 @@ public:
case CFTransformations::PRODUCT: construct_constant_product(); break;
case CFTransformations::SUM: construct_constant_sum(); break;
case CFTransformations::CONCAT: construct_constant_concat(); break;
case CFTransformations::SLICE: construct_constant_slice(); break;
case CFTransformations::DYN_SLICE: construct_constant_dyn_slice(); break;
case CFTransformations::DYN_RESHAPE: construct_constant_dyn_reshape(); break;
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
}
......@@ -114,6 +120,8 @@ private:
void construct_constant_product();
void construct_constant_sum();
void construct_constant_concat();
void construct_constant_slice();
void construct_constant_dyn_slice();
void construct_constant_dyn_reshape();
void construct_constant_transpose();
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include "ngraph/check.hpp"
#include "ngraph/slice_plan.hpp"
using namespace ngraph;
SlicePlan ngraph::make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask)
{
NGRAPH_CHECK(begins.size() == ends.size());
NGRAPH_CHECK(ends.size() == strides.size());
size_t num_slice_indices = begins.size();
size_t num_real_axes = 0;
size_t num_shrink_axes = 0;
size_t num_new_axes = 0;
bool ellipsis_found = false;
// Make a pass over the original slices to make sure there is at most one
// ellipsis, and to count up the number of shrink axes, the number of
// "newaxis"es, and the number of "real" axes (axes that are not newaxis
// and are not the ellipsis).
for (size_t i = 0; i < num_slice_indices; i++)
{
if (ellipsis_mask.count(i))
{
NGRAPH_CHECK(!ellipsis_found);
ellipsis_found = true;
}
else if (new_axis_mask.count(i))
{
num_new_axes++;
}
else
{
if (shrink_axis_mask.count(i))
{
num_shrink_axes++;
}
num_real_axes++;
}
}
NGRAPH_CHECK(num_real_axes <= input_shape.size(),
"num_real_axes=",
num_real_axes,
", input_shape=",
input_shape);
// Figure out how many axes need to be inserted when the ellipsis (which
// may be an implicit ellipsis at the end) is expanded.
size_t ellipsis_size = input_shape.size() - num_real_axes;
// Initialize our slice plan.
SlicePlan p;
p.begins = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.ends = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.strides = std::vector<int64_t>(num_real_axes + ellipsis_size);
p.reshape_in_shape = Shape(num_real_axes + ellipsis_size);
p.reshape_out_shape = Shape(num_new_axes + num_real_axes + ellipsis_size - num_shrink_axes);
p.reverse_axes = AxisSet{};
// Begin a maddeningly delicate loop to desugar the original slice.
//
// * i_in is iterating over the axes of the input shape, which are also the axes of
// p.reshape_in_shape.
// * i_out is iterating over the axes of p.reshape_out_shape
size_t i_in = 0;
size_t i_out = 0;
// If no actual ellipsis exists, there is an "implicit" one at the end,
// which we will handle after the loop. So the logic is wrapped up here,
// allowing it to be used both during and after the loop.
auto expand_ellipsis = [&]() {
for (size_t i = 0; i < ellipsis_size; i++)
{
p.begins[i_in] = 0;
p.ends[i_in] = int64_t(input_shape[i_in]);
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = input_shape[i_in];
p.reshape_out_shape[i_out] = input_shape[i_in];
i_in++;
i_out++;
}
};
for (size_t i = 0; i < num_slice_indices; i++)
{
// If this is a "newaxis", then reshape_out_shape will have a 1 here,
// but reshape_in_shape will not.
if (new_axis_mask.count(i))
{
p.reshape_out_shape[i_out] = 1;
i_out++;
}
// If this is a "shrunken" axis, then reshape_in_shape will have a 1
// here, but reshape_out_shape will not.
else if (shrink_axis_mask.count(i))
{
int64_t begin = begins[i];
// Note that clipping is not used for "shrunken" axes: an
// out-of-bounds index is an error.
NGRAPH_CHECK(begin >= -(int64_t(input_shape[i_in])) &&
begin < int64_t(input_shape[i_in]));
if (begin < 0)
{
begin += int64_t(input_shape[i_in]);
}
p.begins[i_in] = begin;
p.ends[i_in] = begin + 1;
p.strides[i_in] = 1;
p.reshape_in_shape[i_in] = 1;
i_in++;
}
// If this is the ellipsis, expand it.
else if (ellipsis_mask.count(i))
{
expand_ellipsis();
}
// In other cases, we have a nice, ordinary (begin:end:stride) slice.
// We need to adjust for begin/end being masked, and begin/end/stride
// being negative or out of bounds.
else
{
bool is_reverse = strides[i] < 0;
// Adjust the beginning for from-the-right indexing, and clip.
int64_t real_begin = begins[i];
if (lower_bounds_mask.count(i))
{
real_begin = (is_reverse ? int64_t(input_shape[i_in] - 1) : 0);
}
else if (real_begin < 0)
{
real_begin += int64_t(input_shape[i_in]);
}
int64_t max_real_begin = int64_t(input_shape[i_in]) - (is_reverse ? 1 : 0);
real_begin = std::max(int64_t(0), std::min(max_real_begin, real_begin));
// Adjust the ending for from-the-right indexing, and clip.
int64_t real_end = ends[i];
if (upper_bounds_mask.count(i))
{
real_end = (is_reverse ? -1 : int64_t(input_shape[i_in]));
}
else if (real_end < 0)
{
real_end += int64_t(input_shape[i_in]);
}
int64_t min_real_end = (is_reverse ? -1 : 0);
real_end = std::max(min_real_end, std::min(int64_t(input_shape[i_in]), real_end));
// Ensure stride is not zero, and adjust it for backwards slicing.
NGRAPH_CHECK(strides[i] != 0);
int64_t real_stride = std::abs(strides[i]);
// Adjust for reversal if needed. This isn't quite as simple as swapping begin and
// end, due to striding; we have to adjust the end point to be the _actual_ leftmost
// element, in cases where the stride does not evenly divide the span between begin
// and end.
if (is_reverse)
{
real_end += std::max(int64_t(0), real_begin - real_end - 1) % real_stride;
std::swap(real_begin, real_end);
real_begin++;
real_end++;
p.reverse_axes.insert(i_out);
}
// nGraph's slice op does not like it when end < begin, so we truncate for that case
// here.
if (real_end < real_begin)
{
real_end = real_begin;
}
// Compute output dimension.
size_t dim = (real_end <= real_begin
? 0
: size_t(real_end - real_begin - 1) / size_t(real_stride) + 1);
p.reshape_in_shape[i_in] = dim;
p.reshape_out_shape[i_out] = dim;
// Set up the begin/end/stride.
p.begins[i_in] = real_begin;
p.ends[i_in] = real_end;
p.strides[i_in] = real_stride;
i_in++;
i_out++;
}
}
// If there was no ellipsis explicitly given, there is an implicit one at
// the end (it might encompass zero axes, but that's fine).
if (!ellipsis_found)
{
expand_ellipsis();
}
return p;
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <set>
#include "ngraph/axis_set.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
//
// In various places, like ConstantFolding and DynElimination, it is
// useful to transform DynSlice by converting it to a sequence of ops:
//
// Slice (to do the basic slicing)
// |
// v
// Reshape (non-transposing, to handle shrinks)
// |
// v
// Reverse (to emulate backwards stride)
//
// (The Reshape, Reverse, or both may be omitted if they would just be
// identities.)
//
// A SlicePlan is used to collect parameters for these ops.
//
struct SlicePlan
{
// Parameters for the Slice
std::vector<int64_t> begins;
std::vector<int64_t> ends;
std::vector<int64_t> strides;
// Shapes coming into, and going out of, the Reshape.
Shape reshape_in_shape;
Shape reshape_out_shape;
// Parameters for the Reverse
AxisSet reverse_axes;
};
SlicePlan make_slice_plan(const Shape& input_shape,
const std::vector<int64_t>& begins,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& strides,
const AxisSet& lower_bounds_mask,
const AxisSet& upper_bounds_mask,
const AxisSet& new_axis_mask,
const AxisSet& shrink_axis_mask,
const AxisSet& ellipsis_mask);
}
......@@ -739,6 +739,72 @@ TEST(constant_folding, const_floor)
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto slice = make_shared<op::Slice>(constant, Coordinate{2}, Coordinate{15}, Strides{3});
auto f = make_shared<Function>(slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Slice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, const_dyn_slice)
{
Shape shape_in{16};
vector<int> values_in{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
auto constant_data = make_shared<op::Constant>(element::i32, shape_in, values_in);
vector<int> values_lb{2};
auto constant_lb = make_shared<op::Constant>(element::i64, Shape{1}, values_lb);
vector<int> values_ub{15};
auto constant_ub = make_shared<op::Constant>(element::i64, Shape{1}, values_ub);
vector<int> values_strides{3};
auto constant_strides = make_shared<op::Constant>(element::i64, Shape{1}, values_strides);
auto dyn_slice = make_shared<op::DynSlice>(constant_data,
constant_lb,
constant_ub,
constant_strides,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto f = make_shared<Function>(dyn_slice, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynSlice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> sliced_values{3, 6, 9, 12, 15};
ASSERT_EQ(sliced_values, values_out);
}
TEST(constant_folding, constant_dyn_reshape)
{
Shape shape_in{2, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment