Commit 1ad0d723 authored by Amy Zhuang's avatar Amy Zhuang Committed by Scott Cyphers

Add constant folding for Squeeze and Unsqueeze. (#3794)

* Add constant folding for Squeeze and Unsqueeze.

* Address PR feedback.
parent a24a44e2
......@@ -424,6 +424,8 @@ set (SRC
pass/constant_folding_reshape.cpp
pass/constant_folding_reverse.cpp
pass/constant_folding_select.cpp
pass/constant_folding_squeeze.cpp
pass/constant_folding_unsqueeze.cpp
pass/constant_folding_shape_of.cpp
pass/constant_folding_slice.cpp
pass/constant_folding_transpose.cpp
......
......@@ -53,7 +53,9 @@ public:
DYN_RESHAPE,
TRANSPOSE,
RANGE,
SELECT
SELECT,
SQUEEZE,
UNSQUEEZE
};
ConstantFolding(const ngraph::BuildNodeExecutorMap& cfmap = ngraph::BuildNodeExecutorMap())
......@@ -81,6 +83,8 @@ public:
construct_constant_transpose();
construct_constant_range();
construct_constant_select();
construct_constant_squeeze();
construct_constant_unsqueeze();
}
// this allows to specify the order in which matchers will be run
......@@ -119,6 +123,8 @@ public:
case CFTransformations::TRANSPOSE: construct_constant_transpose(); break;
case CFTransformations::RANGE: construct_constant_range(); break;
case CFTransformations::SELECT: construct_constant_select(); break;
case CFTransformations::SQUEEZE: construct_constant_squeeze(); break;
case CFTransformations::UNSQUEEZE: construct_constant_unsqueeze(); break;
}
}
}
......@@ -145,6 +151,8 @@ private:
void construct_constant_transpose();
void construct_constant_range();
void construct_constant_select();
void construct_constant_squeeze();
void construct_constant_unsqueeze();
ngraph::BuildNodeExecutorMap m_cfmap;
};
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/op/fused/squeeze.hpp"
using namespace std;
using namespace ngraph;
template <class T>
shared_ptr<op::Constant> fold_constant_squeeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Squeeze> squeeze)
{
auto out_shape = squeeze->get_shape();
vector<T> out_vec(shape_size(out_shape));
out_vec = constant->get_vector<T>();
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_squeeze()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 1, 4}, pattern::has_class<op::Constant>());
Shape axes_shape{1};
vector<int64_t> values_axes{1};
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
auto squeeze = make_shared<op::Squeeze>(constant_data_label, constant_axes);
auto constant_squeeze_callback = [&, constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_squeeze_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto squeeze_match = static_pointer_cast<op::Squeeze>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(squeeze_match));
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_squeeze_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false, "Encountered 'dynamic' element type in constant_squeeze_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_squeeze<char>(constant_match, squeeze_match);
break;
case element::Type_t::bf16:
replacement = fold_constant_squeeze<bfloat16>(constant_match, squeeze_match);
break;
case element::Type_t::f16:
replacement = fold_constant_squeeze<float16>(constant_match, squeeze_match);
break;
case element::Type_t::f32:
replacement = fold_constant_squeeze<float>(constant_match, squeeze_match);
break;
case element::Type_t::f64:
replacement = fold_constant_squeeze<double>(constant_match, squeeze_match);
break;
case element::Type_t::i8:
replacement = fold_constant_squeeze<int8_t>(constant_match, squeeze_match);
break;
case element::Type_t::i16:
replacement = fold_constant_squeeze<int16_t>(constant_match, squeeze_match);
break;
case element::Type_t::i32:
replacement = fold_constant_squeeze<int32_t>(constant_match, squeeze_match);
break;
case element::Type_t::i64:
replacement = fold_constant_squeeze<int64_t>(constant_match, squeeze_match);
break;
case element::Type_t::u8:
replacement = fold_constant_squeeze<uint8_t>(constant_match, squeeze_match);
break;
case element::Type_t::u16:
replacement = fold_constant_squeeze<uint16_t>(constant_match, squeeze_match);
break;
case element::Type_t::u32:
replacement = fold_constant_squeeze<uint32_t>(constant_match, squeeze_match);
break;
case element::Type_t::u64:
replacement = fold_constant_squeeze<uint64_t>(constant_match, squeeze_match);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto squeeze_matcher =
make_shared<pattern::Matcher>(squeeze, "ConstantFolding.ConstantSqueeze");
this->add_matcher(
squeeze_matcher, constant_squeeze_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
//*****************************************************************************
// Copyright 2017-2019 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "constant_folding.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
using namespace std;
using namespace ngraph;
template <class T>
shared_ptr<op::Constant> fold_constant_unsqueeze(shared_ptr<op::Constant> constant,
shared_ptr<op::Unsqueeze> unsqueeze)
{
auto out_shape = unsqueeze->get_shape();
vector<T> out_vec(shape_size(out_shape));
out_vec = constant->get_vector<T>();
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void pass::ConstantFolding::construct_constant_unsqueeze()
{
auto constant_data_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
Shape axes_shape{1};
vector<int64_t> values_axes{1};
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
auto unsqueeze = make_shared<op::Unsqueeze>(constant_data_label, constant_axes);
auto constant_unsqueeze_callback = [&, constant_data_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_unsqueeze_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = static_pointer_cast<op::Constant>(pattern_map[constant_data_label]);
auto unsqueeze_match = static_pointer_cast<op::Unsqueeze>(m.get_match_root());
NGRAPH_CHECK(revalidate_and_ensure_static(unsqueeze_match));
std::shared_ptr<Node> replacement;
auto type = constant_match->get_element_type();
switch (type)
{
case element::Type_t::undefined:
NGRAPH_CHECK(false,
"Encountered 'undefined' element type in constant_unsqueeze_callback");
break;
case element::Type_t::dynamic:
NGRAPH_CHECK(false,
"Encountered 'dynamic' element type in constant_unsqueeze_callback");
break;
case element::Type_t::boolean:
replacement = fold_constant_unsqueeze<char>(constant_match, unsqueeze_match);
break;
case element::Type_t::bf16:
replacement = fold_constant_unsqueeze<bfloat16>(constant_match, unsqueeze_match);
break;
case element::Type_t::f16:
replacement = fold_constant_unsqueeze<float16>(constant_match, unsqueeze_match);
break;
case element::Type_t::f32:
replacement = fold_constant_unsqueeze<float>(constant_match, unsqueeze_match);
break;
case element::Type_t::f64:
replacement = fold_constant_unsqueeze<double>(constant_match, unsqueeze_match);
break;
case element::Type_t::i8:
replacement = fold_constant_unsqueeze<int8_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::i16:
replacement = fold_constant_unsqueeze<int16_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::i32:
replacement = fold_constant_unsqueeze<int32_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::i64:
replacement = fold_constant_unsqueeze<int64_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::u8:
replacement = fold_constant_unsqueeze<uint8_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::u16:
replacement = fold_constant_unsqueeze<uint16_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::u32:
replacement = fold_constant_unsqueeze<uint32_t>(constant_match, unsqueeze_match);
break;
case element::Type_t::u64:
replacement = fold_constant_unsqueeze<uint64_t>(constant_match, unsqueeze_match);
break;
}
replace_node(m.get_match_root(), replacement);
return true;
};
auto unsqueeze_matcher =
make_shared<pattern::Matcher>(unsqueeze, "ConstantFolding.ConstantUnsqueeze");
this->add_matcher(
unsqueeze_matcher, constant_unsqueeze_callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
......@@ -24,6 +24,62 @@
using namespace ngraph;
using namespace std;
TEST(constant_folding, constant_squeeze)
{
Shape shape_in{2, 4, 1};
Shape shape_out{2, 4};
Shape axes_shape{1};
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
vector<int64_t> values_axes{2};
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
auto squeeze = make_shared<op::Squeeze>(constant, constant_axes);
auto f = make_shared<Function>(squeeze, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Squeeze>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_shape(), shape_out);
auto values_out = new_const->get_vector<float>();
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_unsqueeze)
{
Shape shape_in{2, 4};
Shape shape_out{2, 4, 1, 1};
Shape axes_shape{2};
vector<float> values_in{0, 1, 2, 3, 4, 5, 6, 7};
auto constant = make_shared<op::Constant>(element::f32, shape_in, values_in);
vector<int64_t> values_axes{2, 3};
auto constant_axes = op::Constant::create(element::i64, axes_shape, values_axes);
auto unsqueeze = make_shared<op::Unsqueeze>(constant, constant_axes);
auto f = make_shared<Function>(unsqueeze, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Unsqueeze>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const = as_type_ptr<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
ASSERT_EQ(new_const->get_shape(), shape_out);
auto values_out = new_const->get_vector<float>();
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, constant_reshape)
{
Shape shape_in{2, 4};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment