Commit 5e38aa6b authored by Adam Procter

Add DynElimination for DynReshape

parent d0f03eec
@@ -14,9 +14,12 @@
// limitations under the License.
//*****************************************************************************
#include <numeric>
#include "dyn_elimination.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/experimental/dyn_broadcast.hpp"
#include "ngraph/op/experimental/dyn_reshape.hpp"
#include "ngraph/op/experimental/dyn_slice.hpp"
#include "ngraph/op/experimental/transpose.hpp"
#include "ngraph/op/reshape.hpp"
@@ -33,6 +36,7 @@ pass::DynElimination::DynElimination()
{
construct_transpose();
construct_broadcast();
construct_dyn_slice();
construct_dyn_reshape();
}
@@ -365,7 +369,7 @@ static SlicePlan make_plan(const Shape& input_shape,
return p;
}
void pass::DynElimination::construct_dyn_reshape()
void pass::DynElimination::construct_dyn_slice()
{
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto begins_arg_label =
@@ -434,6 +438,49 @@ void pass::DynElimination::construct_dyn_reshape()
};
auto dyn_slice_matcher =
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynShape");
make_shared<pattern::Matcher>(dyn_slice_pat, "DynElimination.DynSlice");
add_matcher(dyn_slice_matcher, dyn_slice_callback, all_pass_property_off);
}
void pass::DynElimination::construct_dyn_reshape()
{
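// Pattern: a DynReshape whose shape input is a compile-time Constant.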
auto data_arg_label = make_shared<pattern::op::Label>(element::f32, Shape{1, 2, 3});
auto shape_arg_label =
make_shared<pattern::op::Label>(element::i64, Shape{3}, pattern::has_class<op::Constant>());
auto dyn_reshape = make_shared<op::DynReshape>(data_arg_label, shape_arg_label);
auto dyn_reshape_callback = [data_arg_label, shape_arg_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto data_arg = pattern_map[data_arg_label];
auto shape_arg = static_pointer_cast<op::Constant>(pattern_map[shape_arg_label]);
auto dyn_reshape_node = static_pointer_cast<op::DynReshape>(m.get_match_root());
// TODO(amprocte): Can't handle the case where data rank is dynamic even if we know the
// output shape, because static Reshape requires an axis permutation (here an identity) to
// be given. See if we can come up with a workaround.
if (data_arg->get_output_partial_shape(0).rank().is_dynamic())
{
return false;
}
if (dyn_reshape_node->get_output_partial_shape(0).is_dynamic())
{
return false;
}
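// The output shape is fully static at this point, so the DynReshape can be replaced
// with a static Reshape using an identity input-order permutation.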
auto& result_shape = dyn_reshape_node->get_output_shape(0);
AxisVector perm(size_t(data_arg->get_output_partial_shape(0).rank()));
std::iota(perm.begin(), perm.end(), 0);
auto replacement = std::make_shared<op::Reshape>(data_arg, perm, result_shape);
replace_node(dyn_reshape_node, replacement);
return true;
};
auto dyn_reshape_matcher =
make_shared<pattern::Matcher>(dyn_reshape, "DynElimination.DynReshape");
add_matcher(dyn_reshape_matcher, dyn_reshape_callback, all_pass_property_off);
}
@@ -31,6 +31,7 @@ namespace ngraph
private:
void construct_transpose();
void construct_broadcast();
void construct_dyn_slice();
void construct_dyn_reshape();
};
}
@@ -131,3 +131,27 @@ TEST(dyn_elimination, slice)
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 4, 2, 2, 1, 2, 2}));
}
TEST(dyn_elimination, reshape)
{
auto input_arg = make_shared<op::Parameter>(element::f32, Shape{2, 4, 6, 8});
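// Shape pattern {0, 6, -1}: with zero_flag set to true below, a 0 copies the
// corresponding input dimension and -1 is inferred, so {2, 4, 6, 8} becomes {2, 6, 32}.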
auto shape_arg = make_shared<op::Constant>(element::i64, Shape{3}, vector<int64_t>{0, 6, -1});
auto r = make_shared<op::DynReshape>(input_arg, shape_arg, true);
ASSERT_EQ(r->get_element_type(), element::f32);
ASSERT_EQ(r->get_shape(), (Shape{2, 6, 32}));
auto f = make_shared<Function>(r, ParameterVector{input_arg});
pass::Manager pass_manager;
pass_manager.register_pass<pass::DynElimination>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::DynReshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 1);
ASSERT_EQ(f->get_results().at(0)->get_element_type(), element::f32);
ASSERT_EQ(f->get_results().at(0)->get_shape(), (Shape{2, 6, 32}));
}