Commit b9a77a9d authored by Adam Straw's avatar Adam Straw Committed by Scott Cyphers

Constant folding for Reshapes (#1130)

* adding constant propagation pass

* adding test/constant_propagation.cpp

* template make_constant_reshape function

* code review feedback

* add missing files
parent 9be92aae
......@@ -108,6 +108,7 @@ set (SRC
op/util/unary_elementwise.cpp
pass/assign_placement.cpp
pass/algebraic_simplification.cpp
pass/constant_folding.cpp
pass/cse.cpp
pass/dump_sorted.cpp
pass/get_output_element_elimination.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "constant_folding.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
using namespace std;
using namespace ngraph;
template <class T>
shared_ptr<op::Constant> make_constant_reshape(shared_ptr<op::Constant> constant,
shared_ptr<op::Reshape> reshape)
{
auto out_shape = reshape->get_shape();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::reshape<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_shape(),
reshape->get_input_order(),
out_shape);
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [constant_label](pattern::Matcher& m) {
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto reshape_match = dynamic_pointer_cast<op::Reshape>(m.get_match_root());
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
make_constant_reshape<int>(constant_match, reshape_match));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
make_constant_reshape<signed char>(constant_match, reshape_match));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
make_constant_reshape<float>(constant_match, reshape_match));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
make_constant_reshape<double>(constant_match, reshape_match));
return true;
}
return false;
};
auto reshape_matcher = make_shared<pattern::Matcher>(reshape, constant_reshape_callback);
this->add_matcher(reshape_matcher);
}
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#pragma once
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph
{
namespace pass
{
class ConstantFolding;
}
}
class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
{
public:
ConstantFolding()
: GraphRewrite()
{
construct_constant_reshape();
}
private:
void construct_constant_reshape();
};
......@@ -18,6 +18,7 @@ set(SRC
algebraic_simplification.cpp
builder_autobroadcast.cpp
build_graph.cpp
constant_folding.cpp
copy.cpp
core_fusion.cpp
cpio.cpp
......
/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "ngraph/pass/constant_folding.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"
using namespace ngraph;
using namespace std;
TEST(constant_folding, constant_reshape)
{
Shape shape_in{2, 4};
Shape shape_out{2, 4, 1};
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
auto reshape = make_shared<op::Reshape>(constant, AxisVector{0, 1}, shape_out);
auto f = make_shared<Function>(reshape, op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
ASSERT_EQ(values_in, values_out);
}
TEST(constant_folding, constant_reshape_permute)
{
Shape shape_in{2, 4};
Shape shape_out{4, 2};
vector<double> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f64, shape_in, values_in);
auto reshape = make_shared<op::Reshape>(constant, AxisVector{1, 0}, shape_out);
auto f = make_shared<Function>(reshape, op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Reshape>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<double>();
vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
ASSERT_EQ(values_permute, values_out);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment