Commit 3c469863 authored by Adam Procter's avatar Adam Procter

Implement DynElimination for DynReplaceSlice

parent 85582d0c
......@@ -15,8 +15,10 @@
//*****************************************************************************
#include "dyn_elimination.hpp"
#include "ngraph/op/experimental/dyn_replace_slice.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/slice.hpp"
......@@ -30,7 +32,8 @@ pass::DynElimination::DynElimination()
: GraphRewrite()
{
construct_transpose();
construct_dyn_reshape();
construct_dyn_slice();
construct_dyn_replace_slice();
}
void pass::DynElimination::construct_transpose()
......@@ -315,7 +318,7 @@ static SlicePlan make_plan(const Shape& input_shape,
return p;
}
void pass::DynElimination::construct_dyn_reshape()
void pass::DynElimination::construct_dyn_slice()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
......@@ -387,3 +390,89 @@ void pass::DynElimination::construct_dyn_reshape()
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
void pass::DynElimination::construct_dyn_replace_slice()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto replacement_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto ends_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto strides_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_replace_slice_pat = make_shared<op::DynReplaceSlice>(data_arg_label,
replacement_arg_label,
begins_arg_label,
ends_arg_label,
strides_arg_label,
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{},
AxisSet{});
auto dyn_replace_slice_callback = [data_arg_label,
replacement_arg_label,
begins_arg_label,
ends_arg_label,
strides_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto replacement_arg = pattern_map[replacement_arg_label];
auto begins_arg = static_pointer_cast<op::Constant>(pattern_map[begins_arg_label]);
auto ends_arg = static_pointer_cast<op::Constant>(pattern_map[ends_arg_label]);
auto strides_arg = static_pointer_cast<op::Constant>(pattern_map[strides_arg_label]);
auto dyn_replace_slice = static_pointer_cast<op::DynReplaceSlice>(m.get_match_root());
if (data_arg->get_output_partial_shape(0).is_dynamic() ||
replacement_arg->get_output_partial_shape(0).is_dynamic() ||
begins_arg->get_element_type() != element::i64 ||
ends_arg->get_element_type() != element::i64 ||
strides_arg->get_element_type() != element::i64)
{
return false;
}
SlicePlan p = make_plan(data_arg->get_output_shape(0),
begins_arg->get_vector<int64_t>(),
ends_arg->get_vector<int64_t>(),
strides_arg->get_vector<int64_t>(),
dyn_replace_slice->get_lower_bounds_mask(),
dyn_replace_slice->get_upper_bounds_mask(),
dyn_replace_slice->get_new_axis(),
dyn_replace_slice->get_shrink_axis(),
dyn_replace_slice->get_ellipsis_mask());
shared_ptr<Node> substitute_replacement_arg = replacement_arg;
if (!p.reverse_axes.empty())
{
substitute_replacement_arg =
make_shared<op::Reverse>(substitute_replacement_arg, p.reverse_axes);
}
if (p.reshape_in_shape != p.reshape_out_shape)
{
substitute_replacement_arg =
make_shared<op::Reshape>(substitute_replacement_arg,
ngraph::get_default_order(p.reshape_out_shape),
p.reshape_in_shape);
}
auto substitute_rsl =
make_shared<op::ReplaceSlice>(data_arg,
substitute_replacement_arg,
Coordinate(p.begins.begin(), p.begins.end()),
Coordinate(p.ends.begin(), p.ends.end()),
Strides(p.strides.begin(), p.strides.end()));
replace_node(m.get_match_root(), substitute_rsl);
return true;
};
auto dyn_replace_slice_matcher =
make_shared<pattern::Matcher>(dyn_replace_slice_pat, "DynElimination.DynReplaceShape");
add_matcher(dyn_replace_slice_matcher, dyn_replace_slice_callback, all_pass_property_off);
}
......@@ -30,7 +30,8 @@ namespace ngraph
private:
void construct_transpose();
void construct_dyn_reshape();
void construct_dyn_slice();
void construct_dyn_replace_slice();
};
}
}
......@@ -131,3 +131,53 @@ TEST(dyn_elimination, slice)
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
}
TEST(dyn_elimination, replace_slice)
{
// input has shape [2,4,6,8,2,2,2]
// slice in numpy syntax is [0:,:4,2:6:2,7:3:-2,np.newaxis,...,1]
// slice shape should be [2,4,2,2,1,2,2] (so sayeth numpy!)
Shape shape_in{2, 4, 6, 8, 2, 2, 2};
Shape shape_slice{2, 4, 2, 2, 1, 2, 2};
auto input = make_shared<op::Parameter>(element::f32, shape_in);
auto replacement = make_shared<op::Parameter>(element::f32, shape_slice);
auto constant_lb =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{0, 3, 2, 7, 0, 0, 1});
auto constant_ub =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{0, 4, 6, 3, 0, 0, 0});
auto constant_strides =
make_shared<op::Constant>(element::i64, Shape{7}, vector<int64_t>{1, 1, 2, -2, 0, 0, 0});
AxisSet lower_bounds_mask{1};
AxisSet upper_bounds_mask{0};
AxisSet new_axis_mask{4};
AxisSet shrink_mask{6};
AxisSet ellipsis_mask{5};
auto rsl = make_shared<op::DynReplaceSlice>(input,
replacement,
constant_lb,
constant_ub,
constant_strides,
lower_bounds_mask,
upper_bounds_mask,
new_axis_mask,
shrink_mask,
ellipsis_mask);
ASSERT_EQ(rsl->get_element_type(), element::f32);
ASSERT_EQ(rsl->get_shape(), (Shape{2, 4, 6, 8, 2, 2, 2}));
auto f = make_shared<Function>(rsl, ParameterVector{input, replacement});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReplaceSlice>(f), 0);
ASSERT_EQ(count_ops_of_type<op::ReplaceSlice>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(count_ops_of_type<op::Reverse>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 6, 8, 2, 2, 2}));
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment