Commit f642bc4c authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Reshape Sinking (#1701)

* reshape sinking working on mnist_conv

* forgot to add reshape_sinking files

* refactoring of binary case

* Quantize/Dequantize case, fix add case, add assert

* address bob and scott's feedback

* debug

* fix a bug where reshapes are removed too early
parent edc40856
......@@ -89,7 +89,6 @@ namespace ngraph
protected:
/// Throws if the node is invalid.
virtual void validate_and_infer_types();
// Called in constructors during transition
void constructor_validate_and_infer_types();
......@@ -107,6 +106,7 @@ namespace ngraph
virtual void generate_adjoints(autodiff::Adjoints& adjoints, const NodeVector& deltas) {}
public:
virtual void validate_and_infer_types();
// Called after transition
void delayed_validate_and_infer_types();
......
......@@ -110,6 +110,7 @@ set(SRC
pass/cpu_post_layout_optimizations.cpp
pass/cpu_rnn_fusion.cpp
pass/cpu_workspace_insertion.cpp
pass/cpu_reshape_sinking.cpp
)
if (NOT NGRAPH_DEX_ONLY)
......
This diff is collapsed.
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace runtime
{
namespace cpu
{
namespace pass
{
class CPUReshapeSinking : public ngraph::pass::FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}
}
}
......@@ -488,3 +488,13 @@ AxisVector ngraph::get_default_order(size_t rank)
std::iota(begin(default_order), end(default_order), 0);
return default_order;
}
AxisVector ngraph::get_permutation_to_default_order(const AxisVector& axis_order)
{
AxisVector out(axis_order.size());
for (size_t i = 0; i < axis_order.size(); i++)
{
out.at(axis_order[i]) = i;
}
return out;
}
......@@ -204,6 +204,8 @@ namespace ngraph
AxisVector get_default_order(size_t rank);
AxisVector get_default_order(const Shape& shape);
AxisVector get_permutation_to_default_order(const AxisVector& axis_order);
/*
* Return type struct for cache_fprop, with the modified fprop and bprop
* functions
......
......@@ -69,7 +69,7 @@ add_subdirectory(files)
add_subdirectory(util)
if(NGRAPH_CPU_ENABLE)
set(SRC ${SRC} backend_performance.cpp cpu_fusion.cpp cpu_test.cpp)
set(SRC ${SRC} backend_performance.cpp cpu_fusion.cpp cpu_test.cpp cpu_reshape_sinking.cpp)
endif()
if(NGRAPH_GPU_ENABLE)
......
//*****************************************************************************
// Copyright 2017-2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/autodiff/adjoints.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/runtime/cpu/pass/cpu_fusion.hpp"
#include "ngraph/runtime/cpu/pass/cpu_reshape_sinking.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "nlohmann/json.hpp"
#include "util/all_close.hpp"
#include "util/autodiff/backprop_function.hpp"
#include "util/autodiff/numeric_compare.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(cpu_reshape_sinking, edge_splitting)
{
//checks if Reshapes are pushed through op::Abs, but stopped by Sum
Shape shape_nhwc{16, 28, 28, 1};
Shape shape_nchw{16, 1, 28, 28};
auto a = make_shared<op::Parameter>(element::i32, shape_nhwc);
auto reshape = make_shared<op::Reshape>(a, AxisVector{0, 3, 1, 2}, shape_nchw);
auto absn = make_shared<op::Abs>(reshape);
auto absn2 = make_shared<op::Abs>(absn);
auto sum = make_shared<op::Sum>(reshape, AxisSet{0, 1, 2, 3});
auto func = make_shared<Function>(NodeVector{absn2, sum}, op::ParameterVector{a});
pass::Manager pass_manager;
//size_t before_count = count_ops_of_type<op::Reshape>(func);
pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(func);
ASSERT_EQ(func->get_results().at(1)->get_argument(0), sum);
auto new_reshape =
std::dynamic_pointer_cast<op::Reshape>(func->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_reshape);
ASSERT_EQ(new_reshape->get_shape(), shape_nchw);
}
TEST(cpu_reshape_sinking, mnist_conv)
{
const string json_path = file_util::path_join(SERIALIZED_ZOO, "tf_conv_mnist_nhwc.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
pass::Manager pass_manager;
size_t before_count = count_ops_of_type<op::Reshape>(func);
//pass_manager.register_pass<pass::VisualizeTree>("before.pdf");
pass_manager.register_pass<runtime::cpu::pass::CPUReshapeSinking>();
pass_manager.register_pass<pass::ReshapeElimination>();
pass_manager.register_pass<pass::CommonSubexpressionElimination>();
//pass_manager.register_pass<pass::CoreFusion>();
//pass_manager.register_pass<runtime::cpu::pass::CPUFusion>();
//pass_manager.register_pass<pass::VisualizeTree>("after.pdf");
pass_manager.run_passes(func);
size_t before_after = count_ops_of_type<op::Reshape>(func);
ASSERT_LE(before_after, before_count);
}
[{
"name" : "Function_0",
"ops" : [
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_12",
"op" : "Parameter",
"outputs" : ["Parameter_12_0"],
"shape" : [ 2, 224, 224, 3 ]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_11",
"op" : "Parameter",
"outputs" : ["Parameter_11_0"],
"shape" : [10]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_10",
"op" : "Parameter",
"outputs" : ["Parameter_10_0"],
"shape" : [ 37632, 10 ]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_9",
"op" : "Parameter",
"outputs" : ["Parameter_9_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_8",
"op" : "Parameter",
"outputs" : ["Parameter_8_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_7",
"op" : "Parameter",
"outputs" : ["Parameter_7_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_6",
"op" : "Parameter",
"outputs" : ["Parameter_6_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_5",
"op" : "Parameter",
"outputs" : ["Parameter_5_0"],
"shape" : [ 3, 3, 3, 3 ]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_4",
"op" : "Parameter",
"outputs" : ["Parameter_4_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_3",
"op" : "Parameter",
"outputs" : ["Parameter_3_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [3]
},
{
"cacheable" : false,
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 3, 3, 3, 3 ]
},
{
"input_order" : [ 0, 3, 1, 2 ],
"inputs" : ["Parameter_12"],
"name" : "Reshape_13",
"op" : "Reshape",
"output_shape" : [ 2, 3, 224, 224 ],
"outputs" : ["Reshape_13_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_11"],
"name" : "Broadcast_36",
"op" : "Broadcast",
"outputs" : ["Broadcast_36_0"],
"shape" : [ 2, 10 ]
},
{
"input_order" : [ 3, 2, 0, 1 ],
"inputs" : ["Parameter_5"],
"name" : "Reshape_22",
"op" : "Reshape",
"output_shape" : [ 3, 3, 3, 3 ],
"outputs" : ["Reshape_22_0"]
},
{
"input_order" : [ 3, 2, 0, 1 ],
"inputs" : ["Parameter_0"],
"name" : "Reshape_14",
"op" : "Reshape",
"output_shape" : [ 3, 3, 3, 3 ],
"outputs" : ["Reshape_14_0"]
},
{
"data_dilation_strides" : [ 1, 1 ],
"inputs" : [ "Reshape_13", "Reshape_14" ],
"name" : "Convolution_15",
"op" : "Convolution",
"outputs" : ["Convolution_15_0"],
"padding_above" : [ 1, 1 ],
"padding_below" : [ 1, 1 ],
"window_dilation_strides" : [ 1, 1 ],
"window_movement_strides" : [ 1, 1 ]
},
{
"input_order" : [ 0, 2, 3, 1 ],
"inputs" : ["Convolution_15"],
"name" : "Reshape_16",
"op" : "Reshape",
"output_shape" : [ 2, 224, 224, 3 ],
"outputs" : ["Reshape_16_0"]
},
{
"input_order" : [ 0, 3, 1, 2 ],
"inputs" : ["Reshape_16"],
"name" : "Reshape_17",
"op" : "Reshape",
"output_shape" : [ 2, 3, 224, 224 ],
"outputs" : ["Reshape_17_0"]
},
{
"eps" : 1.0009999641624745e-05,
"inputs" : [
"Parameter_1", "Parameter_2", "Reshape_17", "Parameter_3",
"Parameter_4"
],
"name" : "BatchNorm_18",
"op" : "BatchNorm",
"outputs" : ["BatchNorm_18_0"],
"training" : false
},
{
"input_order" : [ 0, 2, 3, 1 ],
"inputs" : ["BatchNorm_18"],
"name" : "Reshape_19",
"op" : "Reshape",
"output_shape" : [ 2, 224, 224, 3 ],
"outputs" : ["Reshape_19_0"]
},
{
"inputs" : ["Reshape_19"],
"name" : "Relu_20",
"op" : "Relu",
"outputs" : ["Relu_20_0"]
},
{
"input_order" : [ 0, 3, 1, 2 ],
"inputs" : ["Relu_20"],
"name" : "Reshape_21",
"op" : "Reshape",
"output_shape" : [ 2, 3, 224, 224 ],
"outputs" : ["Reshape_21_0"]
},
{
"data_dilation_strides" : [ 1, 1 ],
"inputs" : [ "Reshape_21", "Reshape_22" ],
"name" : "Convolution_23",
"op" : "Convolution",
"outputs" : ["Convolution_23_0"],
"padding_above" : [ 1, 1 ],
"padding_below" : [ 1, 1 ],
"window_dilation_strides" : [ 1, 1 ],
"window_movement_strides" : [ 1, 1 ]
},
{
"input_order" : [ 0, 2, 3, 1 ],
"inputs" : ["Convolution_23"],
"name" : "Reshape_24",
"op" : "Reshape",
"output_shape" : [ 2, 224, 224, 3 ],
"outputs" : ["Reshape_24_0"]
},
{
"input_order" : [ 0, 3, 1, 2 ],
"inputs" : ["Reshape_24"],
"name" : "Reshape_25",
"op" : "Reshape",
"output_shape" : [ 2, 3, 224, 224 ],
"outputs" : ["Reshape_25_0"]
},
{
"eps" : 1.0009999641624745e-05,
"inputs" : [
"Parameter_6", "Parameter_7", "Reshape_25", "Parameter_8",
"Parameter_9"
],
"name" : "BatchNorm_26",
"op" : "BatchNorm",
"outputs" : ["BatchNorm_26_0"],
"training" : false
},
{
"input_order" : [ 0, 2, 3, 1 ],
"inputs" : ["BatchNorm_26"],
"name" : "Reshape_27",
"op" : "Reshape",
"output_shape" : [ 2, 224, 224, 3 ],
"outputs" : ["Reshape_27_0"]
},
{
"inputs" : [ "Reshape_27", "Parameter_12" ],
"name" : "Add_28",
"op" : "Add",
"outputs" : ["Add_28_0"]
},
{
"inputs" : ["Add_28"],
"name" : "Relu_29",
"op" : "Relu",
"outputs" : ["Relu_29_0"]
},
{
"input_order" : [ 0, 3, 1, 2 ],
"inputs" : ["Relu_29"],
"name" : "Reshape_30",
"op" : "Reshape",
"output_shape" : [ 2, 3, 224, 224 ],
"outputs" : ["Reshape_30_0"]
},
{
"include_padding_in_avg_computation" : false,
"inputs" : ["Reshape_30"],
"name" : "AvgPool_31",
"op" : "AvgPool",
"outputs" : ["AvgPool_31_0"],
"padding_above" : [ 0, 0 ],
"padding_below" : [ 0, 0 ],
"window_movement_strides" : [ 2, 2 ],
"window_shape" : [ 2, 2 ]
},
{
"input_order" : [ 0, 2, 3, 1 ],
"inputs" : ["AvgPool_31"],
"name" : "Reshape_32",
"op" : "Reshape",
"output_shape" : [ 2, 112, 112, 3 ],
"outputs" : ["Reshape_32_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Reshape_32"],
"name" : "Reshape_34",
"op" : "Reshape",
"output_shape" : [ 2, 37632 ],
"outputs" : ["Reshape_34_0"]
},
{
"inputs" : [ "Reshape_34", "Parameter_10" ],
"name" : "Dot_35",
"op" : "Dot",
"outputs" : ["Dot_35_0"],
"reduction_axes_count" : 1
},
{
"inputs" : [ "Dot_35", "Broadcast_36" ],
"name" : "Add_37",
"op" : "Add",
"outputs" : ["Add_37_0"]
},
{
"inputs" : ["Add_37"],
"name" : "Result_38",
"op" : "Result",
"outputs" : ["Result_38_0"]
}
],
"parameters" : [
"Parameter_0", "Parameter_1", "Parameter_2", "Parameter_3", "Parameter_4",
"Parameter_5", "Parameter_6", "Parameter_7", "Parameter_8", "Parameter_9",
"Parameter_10", "Parameter_11", "Parameter_12"
],
"result" : ["Result_38"]
}]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment