Commit ecd63cfa authored by Jayaram Bobba, committed by Scott Cyphers

Reshape elimination optimization (#3016)

* Combine transpose and reshape pattern into a single reshape

* Optimize reshapes only if in/out shapes don't match

* Default to svg format for visualizing graphs and provide an env variable to change it
parent 59d1504c
@@ -170,7 +170,9 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool transitive)
         if (m_visualize)
         {
-            pass::VisualizeTree vt(base_filename);
+            auto format = std::getenv("NGRAPH_VISUALIZE_TRACING_FORMAT");
+            auto file_ext = format ? std::string(format) : std::string("svg");
+            pass::VisualizeTree vt(base_filename + std::string(".") + file_ext);
             vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
             vt.run_on_module(f_array);
         }
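A note on the new knob: when NGRAPH_VISUALIZE_TRACING_FORMAT is unset, graph dumps default to .svg; any other value is used verbatim as the file extension. A minimal standalone sketch of that fallback follows; the main() harness and the printed message are ours, only the two getenv lines mirror the patch:

#include <cstdlib>
#include <iostream>
#include <string>

int main()
{
    // Same fallback as in pass::Manager::run_passes: an unset
    // NGRAPH_VISUALIZE_TRACING_FORMAT yields "svg"; any other value
    // (e.g. "png" or "dot") is taken verbatim as the file extension.
    auto format = std::getenv("NGRAPH_VISUALIZE_TRACING_FORMAT");
    auto file_ext = format ? std::string(format) : std::string("svg");
    std::cout << "graph files will use ." << file_ext << "\n";
    return 0;
}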
@@ -92,7 +92,22 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
         auto gop = pattern_map[op];

+        auto r2 = static_pointer_cast<op::Reshape>(m.get_match_root());
+        auto r1 = static_pointer_cast<op::Reshape>(r2->get_argument(0));
+
         if (gop->get_shape() != m.get_match_root()->get_shape())
         {
+            // First reshape transposes and second reshape only changes shape
+            // Replace with a transpose that changes shape
+            if (apply_permutation(gop->get_shape(), r1->get_input_order()) == r2->get_shape() &&
+                r2->get_input_order() == get_default_order(r1->get_shape()) &&
+                r1->get_users().size() == 1)
+            {
+                replace_node(m.get_match_root(),
+                             make_shared<op::Reshape>(gop, r1->get_input_order(), r2->get_shape()));
+                return true;
+            }
+            else
+            {
                 NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!";
                 NGRAPH_DEBUG << "gop " << gop->get_name()
@@ -101,10 +116,9 @@ void pass::ReshapeElimination::construct_reshapex2_pattern()
                              << "shape = " << vector_to_string(m.get_match_root()->get_shape());
                 return false;
+            }
         }
-        auto r2 = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
-        auto r1 = dynamic_pointer_cast<op::Reshape>(r2->get_argument(0));
+        // Check for sequence of reshapes/transposes that cancel out.
         auto do_r2 = get_default_order(r1->get_shape());
         auto do_r1 = get_default_order(gop->get_shape());
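The guard in the new branch is easy to check by hand. Below is a self-contained sketch, using our own stand-ins for ngraph's apply_permutation and get_default_order (written only for illustration) and the shapes from the fuse test further down: a {0, 2, 1, 3} transpose of {8, 2, 4, 6} already yields {8, 4, 2, 6}, and the second reshape adds no transposition, so the pair collapses into a single Reshape.

#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<std::size_t>;
using AxisVector = std::vector<std::size_t>;

// Stand-in for ngraph's apply_permutation: out[i] = s[order[i]].
static Shape apply_permutation(const Shape& s, const AxisVector& order)
{
    Shape out(order.size());
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        out[i] = s[order[i]];
    }
    return out;
}

// Stand-in for ngraph's get_default_order: the identity order {0, 1, ..., n-1}.
static AxisVector get_default_order(const Shape& s)
{
    AxisVector order(s.size());
    for (std::size_t i = 0; i < s.size(); ++i)
    {
        order[i] = i;
    }
    return order;
}

int main()
{
    Shape gop_shape{8, 2, 4, 6};     // operand feeding the first reshape
    AxisVector r1_order{0, 2, 1, 3}; // first reshape: a pure transpose
    Shape r1_shape{8, 2, 4, 6};      // declared output shape of the first reshape
    AxisVector r2_order{0, 1, 2, 3}; // second reshape: no transposition
    Shape r2_shape{8, 4, 2, 6};      // second reshape only changes the shape

    // Both conditions from construct_reshapex2_pattern hold, so the two
    // reshapes can be replaced by one Reshape(gop, r1_order, r2_shape).
    assert(apply_permutation(gop_shape, r1_order) == r2_shape);
    assert(r2_order == get_default_order(r1_shape));
    return 0;
}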
@@ -1194,7 +1194,8 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(
     REGISTER_KNOBBED_PASS(BatchFusion, true, ngraph::pass);
     REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
     REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
-    REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
+    REGISTER_KNOBBED_PASS(ReshapeElimination, true, ngraph::pass);
+    REGISTER_KNOBBED_PASS(RecurrentReshapeElimination, false, ngraph::pass);
     REGISTER_KNOBBED_PASS_WITH_ARGS(
         CoreFusion, true, ngraph::pass, ngraph::pass::FusionType::ALL_FUSIONS);
     REGISTER_KNOBBED_PASS_WITH_ARGS(FusedOpDecomposition, true, ngraph::pass, is_supported);
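For readers unfamiliar with the knob macro: REGISTER_KNOBBED_PASS registers a pass when its compiled-in default (the second argument) is true, unless an environment knob overrides that choice, which is why flipping false to true is all it takes to turn ReshapeElimination on by default. A simplified sketch of that pattern follows; the "Pass1:0;Pass2:1" entry format and the parser below are our assumptions for illustration, not the CPU backend's actual macro internals, though the NGRAPH_PASS_ENABLES variable itself is the backend's documented knob.

#include <cstdlib>
#include <map>
#include <sstream>
#include <string>

// Hypothetical parser for a knob string such as "ReshapeElimination:0;BatchFusion:1".
static std::map<std::string, bool> parse_pass_knobs(const char* env_value)
{
    std::map<std::string, bool> knobs;
    if (!env_value)
    {
        return knobs;
    }
    std::istringstream ss(env_value);
    std::string entry;
    while (std::getline(ss, entry, ';'))
    {
        auto colon = entry.find(':');
        if (colon != std::string::npos)
        {
            knobs[entry.substr(0, colon)] = (entry.substr(colon + 1) == "1");
        }
    }
    return knobs;
}

// A pass runs if the knob names it explicitly; otherwise its default applies.
static bool pass_enabled(const std::string& name, bool enable_by_default)
{
    static auto knobs = parse_pass_knobs(std::getenv("NGRAPH_PASS_ENABLES"));
    auto it = knobs.find(name);
    return it != knobs.end() ? it->second : enable_by_default;
}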
@@ -45,6 +45,10 @@ static void visualize_layout_format(const Node& node, ostream& ss)
     {
         return;
     }
+    if (auto reshape = dynamic_cast<const op::Reshape*>(&node))
+    {
+        ss << "\ninput_order=" << reshape->get_input_order();
+    }
     ss << "\nin="
        << runtime::cpu::mkldnn_utils::get_mkldnn_format_string(
            static_cast<mkldnn::memory::format>(in_tvl->get_mkldnn_md().data.format));
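This annotation makes transposes visible in the rendered graphs: a Reshape whose input_order is not the default order performs a transposition, while one with the default order only reinterprets the shape, and the extra label lets the two be told apart at a glance.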
@@ -88,6 +88,50 @@ TEST(reshape_elimination, bn_bprop_rewrite)
 }
 #endif

+TEST(reshape_elimination, transpose_reshape_pattern_fuse)
+{
+    auto generate_func = []() {
+        auto input = make_shared<op::Parameter>(element::f32, Shape{8, 2, 4, 6});
+        auto transpose = make_shared<op::Reshape>(input, AxisVector{0, 2, 1, 3}, Shape{8, 2, 4, 6});
+        auto reshape =
+            make_shared<op::Reshape>(transpose, AxisVector{0, 1, 2, 3}, Shape{8, 4, 2, 6});
+        return make_shared<Function>(reshape, ParameterVector{input});
+    };
+
+    auto fuse_func = generate_func();
+    auto nofuse_func = generate_func();
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ReshapeElimination>();
+    pass_manager.run_passes(fuse_func);
+    ASSERT_TRUE(count_ops_of_type<op::Reshape>(fuse_func) == 1);
+    ASSERT_TRUE(count_ops_of_type<op::Reshape>(nofuse_func) == 2);
+
+    test::Uniform<float> rng(0.0f, 100.0f);
+    vector<vector<float>> args;
+    vector<float> tensor_val(shape_size(Shape{8, 2, 4, 6}));
+    rng.initialize(tensor_val);
+    args.push_back(tensor_val);
+    auto baseline_results = execute(nofuse_func, args, "INTERPRETER");
+    auto optimized_results = execute(fuse_func, args, "INTERPRETER");
+    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
+}
+
+TEST(reshape_elimination, transpose_reshape_pattern_nofuse)
+{
+    auto input = make_shared<op::Parameter>(element::f32, Shape{8, 2, 4, 6});
+    auto transpose = make_shared<op::Reshape>(input, AxisVector{0, 2, 1, 3}, Shape{8, 2, 4, 6});
+    auto reshape = make_shared<op::Reshape>(transpose, AxisVector{2, 1, 0, 3}, Shape{8, 4, 2, 6});
+    auto f = make_shared<Function>(reshape, ParameterVector{input});
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::ReshapeElimination>();
+    pass_manager.run_passes(f);
+    ASSERT_TRUE(count_ops_of_type<op::Reshape>(f) == 2);
+}
+
 TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args)
 {
     Shape shape_w{2, 4};
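The nofuse test above exercises the guard in construct_reshapex2_pattern: its second reshape carries input_order {2, 1, 0, 3}, which is not the default order of the first reshape's shape, so the rewrite refuses to collapse the pair and both Reshape nodes survive, exactly what the final ASSERT checks.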