Commit 4514faf9 authored by Jimin Ha's avatar Jimin Ha Committed by Scott Cyphers

Move CPU ReshapeSinking to Core pass (#2211)

* Move CPU ReshapeSinking to Core pass

* Modify clang compile error

* Fix for style-apply check
parent 922aaaf8
......@@ -149,6 +149,7 @@ set (SRC
pass/pass_config.cpp
pass/propagate_cacheability.cpp
pass/reshape_elimination.cpp
pass/reshape_sinking.cpp
pass/zero_dim_tensor_elimination.cpp
pass/validate_graph.cpp
pass/visualize_tree.cpp
......
......@@ -14,14 +14,13 @@
// limitations under the License.
//*****************************************************************************
#include "cpu_reshape_sinking.hpp"
#include "reshape_sinking.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <set>
#include <unordered_set>
#include "cpu_collapse_dims.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
......@@ -242,8 +241,7 @@ static void convert_binary_to_default_order(
//For each op type we support we can either combine
//two reshapes by replacing the existing Reshape,
//materialize pending reshapes if they can't be propagated through op
bool ngraph::runtime::cpu::pass::CPUReshapeSinking::run_on_function(
std::shared_ptr<ngraph::Function> f)
bool ngraph::pass::ReshapeSinking::run_on_function(std::shared_ptr<ngraph::Function> f)
{
std::unordered_map<std::shared_ptr<Node>, std::shared_ptr<op::Reshape>> reorders;
NodeVector results;
......
......@@ -20,18 +20,12 @@
namespace ngraph
{
namespace runtime
namespace pass
{
namespace cpu
class ReshapeSinking : public ngraph::pass::FunctionPass
{
namespace pass
{
class CPUReshapeSinking : public ngraph::pass::FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
......@@ -114,7 +114,6 @@ set(SRC
pass/cpu_post_layout_optimizations.cpp
pass/cpu_rnn_fusion.cpp
pass/cpu_workspace_insertion.cpp
pass/cpu_reshape_sinking.cpp
ngraph_version.cpp
)
......
......@@ -132,6 +132,7 @@
#include "ngraph/pass/nop_elimination.hpp"
#include "ngraph/pass/propagate_cacheability.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/reshape_sinking.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/cpu/cpu_backend.hpp"
......@@ -172,7 +173,6 @@
#include "ngraph/runtime/cpu/pass/cpu_mat_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_memory_optimization.hpp"
#include "ngraph/runtime/cpu/pass/cpu_post_layout_optimizations.hpp"
#include "ngraph/runtime/cpu/pass/cpu_reshape_sinking.hpp"
#include "ngraph/runtime/cpu/pass/cpu_rnn_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_workspace_insertion.hpp"
#include "ngraph/runtime/cpu/pass/halide_subgraph_extraction.hpp"
......@@ -1093,7 +1093,7 @@ void runtime::cpu::CPU_ExternalFunction::register_common_passes(ngraph::pass::Ma
REGISTER_KNOBBED_PASS(MultiLayerRNNFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPURnnMatFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUBatchFusion, true, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(CPUReshapeSinking, false, runtime::cpu::pass);
REGISTER_KNOBBED_PASS(ReshapeSinking, false, ngraph::pass);
REGISTER_KNOBBED_PASS(ReshapeElimination, false, ngraph::pass);
REGISTER_KNOBBED_PASS(CoreFusion, true, ngraph::pass);
REGISTER_KNOBBED_PASS(CPUFusion, true, runtime::cpu::pass);
......
......@@ -51,6 +51,7 @@ set(SRC
pass_memory_layout.cpp
pattern.cpp
reshape_elimination.cpp
reshape_sinking.cpp
serialize.cpp
shape.cpp
tensor.cpp
......@@ -80,7 +81,7 @@ endif()
if (NGRAPH_CPU_ENABLE)
list(APPEND SRC core_fusion.cpp builder_quantization.cpp)
list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp cpu_debugger.cpp)
list(APPEND SRC backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_debugger.cpp)
if (NGRAPH_HALIDE)
list(APPEND SRC halide.cpp)
endif()
......
......@@ -33,9 +33,8 @@
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/reshape_sinking.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_reshape_sinking.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
......@@ -49,7 +48,7 @@
using namespace ngraph;
using namespace std;
TEST(cpu_reshape_sinking, edge_splitting)
TEST(reshape_sinking, edge_splitting)
{
//checks if Reshapes are pushed through op::Abs, but stopped by Sum
Shape shape_nhwc{16, 28, 28, 1};
......@@ -63,7 +62,7 @@ TEST(cpu_reshape_sinking, edge_splitting)
pass::Manager pass_manager;
//size_t before_count = count_ops_of_type<op::Reshape>(func);
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
pass_manager.register_pass<pass::ReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
......@@ -75,7 +74,7 @@ TEST(cpu_reshape_sinking, edge_splitting)
ASSERT_EQ(new_reshape->get_shape(), shape_nchw);
}
TEST(cpu_reshape_sinking, broadcast_swimming)
TEST(reshape_sinking, broadcast_swimming)
{
Shape shape_nchw{1, 32, 536, 536};
Shape shape_nhwc{1, 536, 536, 32};
......@@ -102,7 +101,7 @@ TEST(cpu_reshape_sinking, broadcast_swimming)
auto func = make_shared<Function>(NodeVector{relu}, ParameterVector{bias, input, weights});
pass::Manager pass_manager;
pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
pass_manager.register_pass<pass::ReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
pass_manager.run_passes(func);
......@@ -112,7 +111,7 @@ TEST(cpu_reshape_sinking, broadcast_swimming)
ASSERT_EQ(add->get_argument(1), conv);
}
TEST(cpu_reshape_sinking, mnist_conv)
TEST(reshape_sinking, mnist_conv)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "tf_conv_mnist_nhwc.json");
const string json_string = file_util::read_file_to_string(json_path);
......@@ -121,7 +120,7 @@ TEST(cpu_reshape_sinking, mnist_conv)
pass::Manager pass_manager;
size_t before_count = count_ops_of_type<op::Reshape>(func);
//pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
pass_manager.register_pass<pass::ReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
//pass_manager.register_pass<pass::CoreFusion>();
......@@ -131,3 +130,36 @@ TEST(cpu_reshape_sinking, mnist_conv)
size_t before_after = count_ops_of_type<op::Reshape>(func);
ASSERT_LE(before_after, before_count);
}
TEST(reshape_sinking, nasnet_pooladd)
{
Shape input_shape{1, 3, 3, 1};
auto input_type = element::f32;
auto output_type = element::f32;
auto X = make_shared<op::Parameter>(input_type, input_shape);
auto c_weights = op::Constant::create(input_type, Shape{1, 1, 1, 1}, {3});
auto reshape1 = make_shared<op::Reshape>(X, AxisVector{0, 3, 1, 2}, Shape{1, 1, 3, 3});
auto avgpool =
make_shared<op::AvgPool>(reshape1, Shape{1, 1}, Strides{1, 1}, Shape{0, 0}, Shape{0, 0});
auto reshape2 = make_shared<op::Reshape>(avgpool, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1});
auto maxpool =
make_shared<op::MaxPool>(reshape1, Shape{1, 1}, Strides{1, 1}, Shape{0, 0}, Shape{0, 0});
auto reshape3 = make_shared<op::Reshape>(maxpool, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1});
auto const1 = op::Constant::create(input_type, Shape{1, 3, 3, 1}, {3});
auto add1 = make_shared<op::Add>(reshape3, const1);
auto add2 = make_shared<op::Add>(add1, reshape2);
auto func = make_shared<Function>(add2, ParameterVector{X});
pass::Manager pass_manager;
size_t before_count = count_ops_of_type<op::Reshape>(func);
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<pass::ReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(func);
size_t before_after = count_ops_of_type<op::Reshape>(func);
ASSERT_LE(before_after, before_count);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment