Commit f5930d37 authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Scott Cyphers

Reshape Transformations + Simplification pass (#427)

* simplification pass

* serializer change to test models

* some small test fixes

* addressing Scott's feedback

* missed one nn

* formatting fixes

* simplification -> reshape_elimination
parent 0bf21af9
......@@ -85,6 +85,7 @@ set (SRC
pass/memory_layout.cpp
pass/memory_visualize.cpp
pass/pass.cpp
pass/reshape_elimination.cpp
pass/visualize_tree.cpp
pattern/matcher.cpp
runtime/aligned_buffer.cpp
......
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include "reshape_elimination.hpp"
#include <algorithm>
#include <iostream>
#include <unordered_set>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ops/add.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/broadcast.hpp"
#include "ngraph/ops/dot.hpp"
#include "ngraph/ops/parameter.hpp"
#include "ngraph/ops/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
template <typename T>
static std::vector<T> apply_permutation(std::vector<T> input, ngraph::AxisVector order)
{
if (input.size() != order.size())
{
throw "input and order sizes don't match!";
}
std::vector<T> output(input.size());
for (size_t i = 0; i < order.size(); i++)
{
output[i] = input.at(order.at(i));
}
return output;
}
void ngraph::pass::ReshapeElimination::construct_identity_reshape_pattern()
{
Shape shape_op{3};
Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op];
auto r1 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
if (r1->get_shape() != gop->get_shape())
{
NGRAPH_DEBUG << "Not a no-op; Shapes are different!";
return nn;
}
Shape do_r1(r1->get_shape().size());
std::iota(begin(do_r1), end(do_r1), 0);
if (do_r1 != r1->get_input_order())
{
NGRAPH_DEBUG << "Not a no-op; Not in default input order!";
return nn;
}
return gop;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape1, callback);
this->add_matcher(m);
}
void ngraph::pass::ReshapeElimination::construct_reshapex2_pattern()
{
Shape shape_op{3};
Shape shape_r1{1, 3};
auto op = std::make_shared<pattern::op::Label>(element::f32, shape_op);
auto reshape1 = std::make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
auto reshape2 = std::make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op);
auto callback = [op](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
<< m.match_root()->get_name();
auto pattern_map = m.get_pattern_map();
std::shared_ptr<ngraph::Node> nn;
auto gop = pattern_map[op];
if (gop->get_shape() != m.match_root()->get_shape())
{
NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!";
NGRAPH_DEBUG << "gop " << gop->get_name()
<< "shape = " << vector_to_string(gop->get_shape());
NGRAPH_DEBUG << "match_root " << m.match_root()->get_name()
<< "shape = " << vector_to_string(m.match_root()->get_shape());
return nn;
}
auto r2 = std::dynamic_pointer_cast<op::Reshape>(m.match_root());
auto r1 = std::dynamic_pointer_cast<op::Reshape>(r2->get_input_op(0));
Shape do_r2(r1->get_shape().size());
std::iota(begin(do_r2), end(do_r2), 0);
Shape do_r1(gop->get_shape().size());
std::iota(begin(do_r1), end(do_r1), 0);
NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order())
<< "do_r1 = " << vector_to_string(do_r1);
NGRAPH_DEBUG << "r2's i/o = " << vector_to_string(r2->get_input_order())
<< "do_r2 = " << vector_to_string(do_r2);
if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
{
NGRAPH_DEBUG << "Two reshapes were removed!";
return gop;
}
auto perm1 = apply_permutation(do_r1, r1->get_input_order());
auto perm2 = apply_permutation(perm1, r2->get_input_order());
if (perm2 == do_r1)
{
NGRAPH_DEBUG << "Two transposes were removed!";
return gop;
}
return nn;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape2, callback);
this->add_matcher(m);
}
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace pass
{
class ReshapeElimination;
}
}
class ngraph::pass::ReshapeElimination : public ngraph::pass::GraphRewrite
{
public:
ReshapeElimination()
: GraphRewrite()
{
construct_identity_reshape_pattern();
construct_reshapex2_pattern();
}
private:
void construct_identity_reshape_pattern();
void construct_reshapex2_pattern();
};
......@@ -62,6 +62,7 @@
#include "ngraph/ops/sin.hpp"
#include "ngraph/ops/sinh.hpp"
#include "ngraph/ops/slice.hpp"
#include "ngraph/ops/sqrt.hpp"
#include "ngraph/ops/subtract.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/ops/tan.hpp"
......@@ -660,6 +661,10 @@ static shared_ptr<ngraph::Function>
auto strides = node_js.at("strides").get<vector<size_t>>();
node = make_shared<op::Slice>(args[0], lower_bounds, upper_bounds, strides);
}
else if (node_op == "Sqrt")
{
node = make_shared<op::Sqrt>(args[0]);
}
else if (node_op == "Subtract")
{
node = make_shared<op::Subtract>(args[0], args[1]);
......@@ -956,6 +961,9 @@ static json write(const Node& n)
node["upper_bounds"] = tmp->get_upper_bounds();
node["strides"] = tmp->get_strides();
}
else if (node_op == "Sqrt")
{
}
else if (node_op == "Subtract")
{
}
......
......@@ -41,6 +41,7 @@ set (SRC
serialize.cpp
pattern.cpp
shape.cpp
reshape_elimination.cpp
tensor.cpp
type_prop.cpp
util/autodiff/backprop_function.cpp
......
This diff is collapsed.
[{
"name" : "Function_0",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_36",
"op" : "Parameter",
"outputs" : ["Parameter_36_0"],
"shape" : [ 2, 3 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_4",
"op" : "Parameter",
"outputs" : ["Parameter_4_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_3",
"op" : "Parameter",
"outputs" : ["Parameter_3_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_1",
"op" : "Parameter",
"outputs" : ["Parameter_1_0"],
"shape" : [3]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 2, 3 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_18",
"op" : "Constant",
"outputs" : ["Constant_18_0"],
"shape" : [],
"value" : ["0.001"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_12",
"op" : "Constant",
"outputs" : ["Constant_12_0"],
"shape" : [3],
"value" : [ "2", "2", "2" ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_6",
"op" : "Constant",
"outputs" : ["Constant_6_0"],
"shape" : [3],
"value" : [ "2", "2", "2" ]
},
{
"input_order" : [0],
"inputs" : ["Parameter_2"],
"name" : "Reshape_29",
"op" : "Reshape",
"output_shape" : [ 1, 3 ],
"outputs" : ["Reshape_29_0"]
},
{
"input_order" : [0],
"inputs" : ["Parameter_1"],
"name" : "Reshape_28",
"op" : "Reshape",
"output_shape" : [ 1, 3 ],
"outputs" : ["Reshape_28_0"]
},
{
"inputs" : [ "Parameter_0", "Parameter_0" ],
"name" : "Multiply_10",
"op" : "Multiply",
"outputs" : ["Multiply_10_0"]
},
{
"inputs" : ["Parameter_0"],
"name" : "Sum_5",
"op" : "Sum",
"outputs" : ["Sum_5_0"],
"reduction_axes" : [0]
},
{
"inputs" : ["Parameter_0"],
"name" : "Sum_9",
"op" : "Sum",
"outputs" : ["Sum_9_0"],
"reduction_axes" : [0]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_18"],
"name" : "Broadcast_19",
"op" : "Broadcast",
"outputs" : ["Broadcast_19_0"],
"shape" : [ 2, 3 ]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Reshape_29"],
"name" : "Reshape_33",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_33_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Reshape_28"],
"name" : "Reshape_30",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_30_0"]
},
{
"inputs" : ["Multiply_10"],
"name" : "Sum_11",
"op" : "Sum",
"outputs" : ["Sum_11_0"],
"reduction_axes" : [0]
},
{
"inputs" : [ "Sum_5", "Constant_6" ],
"name" : "Divide_7",
"op" : "Divide",
"outputs" : ["Divide_7_0"]
},
{
"inputs" : [ "Sum_9", "Sum_9" ],
"name" : "Multiply_13",
"op" : "Multiply",
"outputs" : ["Multiply_13_0"]
},
{
"axes" : [0],
"inputs" : ["Reshape_33"],
"name" : "Broadcast_34",
"op" : "Broadcast",
"outputs" : ["Broadcast_34_0"],
"shape" : [ 2, 3 ]
},
{
"axes" : [0],
"inputs" : ["Reshape_30"],
"name" : "Broadcast_31",
"op" : "Broadcast",
"outputs" : ["Broadcast_31_0"],
"shape" : [ 2, 3 ]
},
{
"input_order" : [0],
"inputs" : ["Divide_7"],
"name" : "Reshape_8",
"op" : "Reshape",
"output_shape" : [ 1, 3 ],
"outputs" : ["Reshape_8_0"]
},
{
"inputs" : [ "Multiply_13", "Constant_12" ],
"name" : "Divide_14",
"op" : "Divide",
"outputs" : ["Divide_14_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Reshape_8"],
"name" : "Reshape_24",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_24_0"]
},
{
"inputs" : [ "Sum_11", "Divide_14" ],
"name" : "Subtract_15",
"op" : "Subtract",
"outputs" : ["Subtract_15_0"]
},
{
"axes" : [0],
"inputs" : ["Reshape_24"],
"name" : "Broadcast_25",
"op" : "Broadcast",
"outputs" : ["Broadcast_25_0"],
"shape" : [ 2, 3 ]
},
{
"inputs" : [ "Subtract_15", "Constant_12" ],
"name" : "Divide_16",
"op" : "Divide",
"outputs" : ["Divide_16_0"]
},
{
"inputs" : [ "Parameter_0", "Broadcast_25" ],
"name" : "Subtract_26",
"op" : "Subtract",
"outputs" : ["Subtract_26_0"]
},
{
"input_order" : [0],
"inputs" : ["Divide_16"],
"name" : "Reshape_17",
"op" : "Reshape",
"output_shape" : [ 1, 3 ],
"outputs" : ["Reshape_17_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Reshape_17"],
"name" : "Reshape_20",
"op" : "Reshape",
"output_shape" : [3],
"outputs" : ["Reshape_20_0"]
},
{
"axes" : [0],
"inputs" : ["Reshape_20"],
"name" : "Broadcast_21",
"op" : "Broadcast",
"outputs" : ["Broadcast_21_0"],
"shape" : [ 2, 3 ]
},
{
"inputs" : [ "Broadcast_21", "Broadcast_19" ],
"name" : "Add_22",
"op" : "Add",
"outputs" : ["Add_22_0"]
},
{
"inputs" : ["Add_22"],
"name" : "Sqrt_23",
"op" : "Sqrt",
"outputs" : ["Sqrt_23_0"]
},
{
"inputs" : [ "Subtract_26", "Sqrt_23" ],
"name" : "Divide_27",
"op" : "Divide",
"outputs" : ["Divide_27_0"]
},
{
"inputs" : [ "Divide_27", "Broadcast_31" ],
"name" : "Multiply_32",
"op" : "Multiply",
"outputs" : ["Multiply_32_0"]
},
{
"inputs" : [ "Multiply_32", "Broadcast_34" ],
"name" : "Add_35",
"op" : "Add",
"outputs" : ["Add_35_0"]
},
{
"inputs" : [ "Add_35", "Parameter_36" ],
"name" : "Multiply_37",
"op" : "Multiply",
"outputs" : ["Multiply_37_0"]
}
],
"parameters" : [
"Parameter_0", "Parameter_1", "Parameter_2", "Parameter_3", "Parameter_4",
"Parameter_36"
],
"result" : ["Multiply_37"]
}]
[{
"name" : "Function_1",
"ops" : [
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_21",
"op" : "Parameter",
"outputs" : ["Parameter_21_0"],
"shape" : [10]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_20",
"op" : "Parameter",
"outputs" : ["Parameter_20_0"],
"shape" : [ 10, 64 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_12",
"op" : "Parameter",
"outputs" : ["Parameter_12_0"],
"shape" : [64]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_11",
"op" : "Parameter",
"outputs" : ["Parameter_11_0"],
"shape" : [ 64, 128 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_3",
"op" : "Parameter",
"outputs" : ["Parameter_3_0"],
"shape" : [128]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_2",
"op" : "Parameter",
"outputs" : ["Parameter_2_0"],
"shape" : [ 128, 784 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_0",
"op" : "Parameter",
"outputs" : ["Parameter_0_0"],
"shape" : [ 64, 1, 28, 28 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Parameter_26",
"op" : "Parameter",
"outputs" : ["Parameter_26_0"],
"shape" : [ 64, 10 ]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_17",
"op" : "Constant",
"outputs" : ["Constant_17_0"],
"shape" : [],
"value" : ["0"]
},
{
"element_type" : "float",
"inputs" : [],
"name" : "Constant_8",
"op" : "Constant",
"outputs" : ["Constant_8_0"],
"shape" : [],
"value" : ["0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_20"],
"name" : "Reshape_22",
"op" : "Reshape",
"output_shape" : [ 64, 10 ],
"outputs" : ["Reshape_22_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_12"],
"name" : "Broadcast_15",
"op" : "Broadcast",
"outputs" : ["Broadcast_15_0"],
"shape" : [ 64, 64 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_11"],
"name" : "Reshape_13",
"op" : "Reshape",
"output_shape" : [ 128, 64 ],
"outputs" : ["Reshape_13_0"]
},
{
"axes" : [0],
"inputs" : ["Parameter_3"],
"name" : "Broadcast_6",
"op" : "Broadcast",
"outputs" : ["Broadcast_6_0"],
"shape" : [ 64, 128 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Parameter_2"],
"name" : "Reshape_4",
"op" : "Reshape",
"output_shape" : [ 784, 128 ],
"outputs" : ["Reshape_4_0"]
},
{
"input_order" : [ 0, 1, 2, 3 ],
"inputs" : ["Parameter_0"],
"name" : "Reshape_1",
"op" : "Reshape",
"output_shape" : [ 64, 784 ],
"outputs" : ["Reshape_1_0"]
},
{
"inputs" : ["Parameter_26"],
"name" : "Sum_27",
"op" : "Sum",
"outputs" : ["Sum_27_0"],
"reduction_axes" : [0]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_17"],
"name" : "Broadcast_18",
"op" : "Broadcast",
"outputs" : ["Broadcast_18_0"],
"shape" : [ 64, 64 ]
},
{
"axes" : [ 0, 1 ],
"inputs" : ["Constant_8"],
"name" : "Broadcast_9",
"op" : "Broadcast",
"outputs" : ["Broadcast_9_0"],
"shape" : [ 64, 128 ]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_22"],
"name" : "Reshape_28",
"op" : "Reshape",
"output_shape" : [ 10, 64 ],
"outputs" : ["Reshape_28_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_13"],
"name" : "Reshape_42",
"op" : "Reshape",
"output_shape" : [ 64, 128 ],
"outputs" : ["Reshape_42_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_4"],
"name" : "Reshape_56",
"op" : "Reshape",
"output_shape" : [ 128, 784 ],
"outputs" : ["Reshape_56_0"]
},
{
"inputs" : [ "Reshape_1", "Reshape_4" ],
"name" : "Dot_5",
"op" : "Dot",
"outputs" : ["Dot_5_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_1"],
"name" : "Reshape_58",
"op" : "Reshape",
"output_shape" : [ 784, 64 ],
"outputs" : ["Reshape_58_0"]
},
{
"inputs" : [ "Parameter_26", "Reshape_28" ],
"name" : "Dot_29",
"op" : "Dot",
"outputs" : ["Dot_29_0"]
},
{
"inputs" : [ "Dot_5", "Broadcast_6" ],
"name" : "Add_7",
"op" : "Add",
"outputs" : ["Add_7_0"]
},
{
"inputs" : [ "Add_7", "Broadcast_9" ],
"name" : "Maximum_10",
"op" : "Maximum",
"outputs" : ["Maximum_10_0"]
},
{
"inputs" : [ "Add_7", "Broadcast_9" ],
"name" : "Greater_48",
"op" : "Greater",
"outputs" : ["Greater_48_0"]
},
{
"inputs" : [ "Maximum_10", "Reshape_13" ],
"name" : "Dot_14",
"op" : "Dot",
"outputs" : ["Dot_14_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Maximum_10"],
"name" : "Reshape_44",
"op" : "Reshape",
"output_shape" : [ 128, 64 ],
"outputs" : ["Reshape_44_0"]
},
{
"inputs" : ["Greater_48"],
"name" : "Convert_49",
"op" : "Convert",
"outputs" : ["Convert_49_0"],
"target_type" : "float"
},
{
"inputs" : [ "Dot_14", "Broadcast_15" ],
"name" : "Add_16",
"op" : "Add",
"outputs" : ["Add_16_0"]
},
{
"inputs" : [ "Add_16", "Broadcast_18" ],
"name" : "Maximum_19",
"op" : "Maximum",
"outputs" : ["Maximum_19_0"]
},
{
"inputs" : [ "Add_16", "Broadcast_18" ],
"name" : "Greater_34",
"op" : "Greater",
"outputs" : ["Greater_34_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Maximum_19"],
"name" : "Reshape_30",
"op" : "Reshape",
"output_shape" : [ 64, 64 ],
"outputs" : ["Reshape_30_0"]
},
{
"inputs" : ["Greater_34"],
"name" : "Convert_35",
"op" : "Convert",
"outputs" : ["Convert_35_0"],
"target_type" : "float"
},
{
"inputs" : [ "Reshape_30", "Parameter_26" ],
"name" : "Dot_31",
"op" : "Dot",
"outputs" : ["Dot_31_0"]
},
{
"inputs" : [ "Dot_29", "Convert_35" ],
"name" : "Multiply_36",
"op" : "Multiply",
"outputs" : ["Multiply_36_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Dot_31"],
"name" : "Reshape_32",
"op" : "Reshape",
"output_shape" : [ 64, 10 ],
"outputs" : ["Reshape_32_0"]
},
{
"inputs" : ["Multiply_36"],
"name" : "Sum_41",
"op" : "Sum",
"outputs" : ["Sum_41_0"],
"reduction_axes" : [0]
},
{
"inputs" : [ "Multiply_36", "Reshape_42" ],
"name" : "Dot_43",
"op" : "Dot",
"outputs" : ["Dot_43_0"]
},
{
"inputs" : [ "Reshape_44", "Multiply_36" ],
"name" : "Dot_45",
"op" : "Dot",
"outputs" : ["Dot_45_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_32"],
"name" : "Reshape_33",
"op" : "Reshape",
"output_shape" : [ 10, 64 ],
"outputs" : ["Reshape_33_0"]
},
{
"inputs" : [ "Dot_43", "Convert_49" ],
"name" : "Multiply_50",
"op" : "Multiply",
"outputs" : ["Multiply_50_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Dot_45"],
"name" : "Reshape_46",
"op" : "Reshape",
"output_shape" : [ 128, 64 ],
"outputs" : ["Reshape_46_0"]
},
{
"inputs" : ["Multiply_50"],
"name" : "Sum_55",
"op" : "Sum",
"outputs" : ["Sum_55_0"],
"reduction_axes" : [0]
},
{
"inputs" : [ "Multiply_50", "Reshape_56" ],
"name" : "Dot_57",
"op" : "Dot",
"outputs" : ["Dot_57_0"]
},
{
"inputs" : [ "Reshape_58", "Multiply_50" ],
"name" : "Dot_59",
"op" : "Dot",
"outputs" : ["Dot_59_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_46"],
"name" : "Reshape_47",
"op" : "Reshape",
"output_shape" : [ 64, 128 ],
"outputs" : ["Reshape_47_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Dot_57"],
"name" : "Reshape_62",
"op" : "Reshape",
"output_shape" : [ 64, 1, 28, 28 ],
"outputs" : ["Reshape_62_0"]
},
{
"input_order" : [ 0, 1 ],
"inputs" : ["Dot_59"],
"name" : "Reshape_60",
"op" : "Reshape",
"output_shape" : [ 784, 128 ],
"outputs" : ["Reshape_60_0"]
},
{
"input_order" : [ 1, 0 ],
"inputs" : ["Reshape_60"],
"name" : "Reshape_61",
"op" : "Reshape",
"output_shape" : [ 128, 784 ],
"outputs" : ["Reshape_61_0"]
}
],
"parameters" : [
"Parameter_26", "Parameter_0", "Parameter_2", "Parameter_3", "Parameter_11",
"Parameter_12", "Parameter_20", "Parameter_21"
],
"result" : [
"Reshape_62", "Reshape_61", "Sum_55", "Reshape_47", "Sum_41", "Reshape_33",
"Sum_27"
]
}]
// ----------------------------------------------------------------------------
// Copyright 2018 Nervana Systems Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// ----------------------------------------------------------------------------
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/json.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/ops/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/any.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(reshape_elimination, remove_reshape)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::ReshapeElimination>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
size_t count_before = count_ops_of_type<op::Reshape>(func);
pass_manager.run_passes(func);
size_t count_after = count_ops_of_type<op::Reshape>(func);
ASSERT_TRUE(count_after < count_before);
}
TEST(reshape_elimination, remove_tranpose)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::ReshapeElimination>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/tranpose.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
size_t count_before = count_ops_of_type<op::Reshape>(func);
pass_manager.run_passes(func);
size_t count_after = count_ops_of_type<op::Reshape>(func);
ASSERT_TRUE(count_after < count_before);
}
TEST(reshape_elimination, bn_bprop_rewrite)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::ReshapeElimination>();
const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_bprop.json");
const string json_string = file_util::read_file_to_string(json_path);
stringstream ss(json_string);
shared_ptr<Function> func = ngraph::deserialize(ss);
size_t count_before = count_ops_of_type<op::Reshape>(func);
pass_manager.run_passes(func);
size_t count_after = count_ops_of_type<op::Reshape>(func);
ASSERT_TRUE(count_after < count_before);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment