Commit 00a76f3b authored by Nick Korovaiko's avatar Nick Korovaiko Committed by Robert Kimball

Constant Folding : Constant + Pad (#1528)

* constant + pad

* adding broadcast test back
parent 446cf07b
......@@ -20,10 +20,12 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
using namespace std;
......@@ -45,6 +47,78 @@ shared_ptr<op::Constant> make_constant_reshape(shared_ptr<op::Constant> constant
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
// Folds a Pad applied to a Constant into a single new Constant that holds
// the already-padded data.
//
// `constant` supplies the input values; `pad` supplies the output shape and
// the below/above/interior padding amounts. The fill scalar is read from the
// Pad node's second argument, which the matcher guarantees is a Constant.
template <class T>
shared_ptr<op::Constant> make_constant_pad(shared_ptr<op::Constant> constant,
                                           shared_ptr<op::Pad> pad)
{
    const auto& padded_shape = pad->get_shape();
    vector<T> padded_values(shape_size(padded_shape));

    // Scalar value written into every padded element (Pad's second input).
    auto fill_value = std::dynamic_pointer_cast<op::Constant>(pad->get_argument(1));

    runtime::reference::pad<T>(constant->get_vector<T>().data(),
                               fill_value->get_vector<T>().data(),
                               padded_values.data(),
                               constant->get_shape(),
                               padded_shape,
                               pad->get_padding_below(),
                               pad->get_padding_above(),
                               pad->get_padding_interior());

    return make_shared<op::Constant>(constant->get_element_type(), padded_shape, padded_values);
}
// Registers a matcher that replaces Pad(Constant, Constant, ...) with a
// single pre-computed Constant. The concrete padding amounts in the pattern
// are placeholders; the matcher matches any Pad whose two inputs are
// Constants.
void ngraph::pass::ConstantFolding::construct_constant_pad()
{
    auto is_constant = pattern::has_class<op::Constant>();
    auto constant_label = make_shared<pattern::op::Label>(element::f32, Shape{6}, is_constant);
    // The pad value must itself be a constant scalar for folding to be possible.
    auto pad_value_label = make_shared<pattern::op::Label>(element::f32, Shape{}, is_constant);
    Shape padding_below{0};
    Shape padding_above{0};
    Shape padding_interior{0};

    auto pad = make_shared<op::Pad>(
        constant_label, pad_value_label, padding_below, padding_above, padding_interior);

    auto constant_pad_callback = [constant_label](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for constant_pad_callback against node = "
                     << m.get_match_root()->get_name();

        auto pattern_map = m.get_pattern_map();

        // Both casts are guaranteed to succeed by the pattern predicates above.
        auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
        auto pad_match = dynamic_pointer_cast<op::Pad>(m.get_match_root());

        // Dispatch on the element type; unsupported types are left unfolded.
        auto type = constant_match->get_element_type();
        if (type == element::i32)
        {
            // int32_t (not plain int, whose width is implementation-defined)
            // to match element::i32 exactly, consistent with int8_t below.
            replace_node(m.get_match_root(),
                         make_constant_pad<int32_t>(constant_match, pad_match));
            return true;
        }
        else if (type == element::i8)
        {
            replace_node(m.get_match_root(), make_constant_pad<int8_t>(constant_match, pad_match));
            return true;
        }
        else if (type == element::f32)
        {
            replace_node(m.get_match_root(), make_constant_pad<float>(constant_match, pad_match));
            return true;
        }
        else if (type == element::f64)
        {
            replace_node(m.get_match_root(), make_constant_pad<double>(constant_match, pad_match));
            return true;
        }

        return false;
    };

    auto pad_matcher = make_shared<pattern::Matcher>(pad, constant_pad_callback);
    this->add_matcher(pad_matcher);
}
void ngraph::pass::ConstantFolding::construct_constant_reshape()
{
auto constant_label = make_shared<pattern::op::Label>(
......
......@@ -34,9 +34,11 @@ public:
{
construct_constant_reshape();
construct_constant_broadcast();
construct_constant_pad();
}
private:
void construct_constant_reshape();
void construct_constant_broadcast();
void construct_constant_pad();
};
......@@ -99,3 +99,67 @@ TEST(constant_folding, constant_broadcast)
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out);
}
TEST(constant_folding, constant_pad_exterior)
{
    // Pad {777, 888} with the value 111: one copy below, two above, no
    // interior padding. ConstantFolding should replace the Pad node with a
    // single pre-computed Constant.
    Shape shape_in{2};

    vector<int> values_in{777, 888};
    auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
    auto pad_value = make_shared<op::Constant>(element::i32, Shape{}, vector<int>{111});

    Shape padding_below{1};
    Shape padding_above{2};
    Shape padding_interior{0};

    // NOTE: was misleadingly named `broadcast` (copy-paste from the
    // broadcast test); this is a Pad node.
    auto pad =
        make_shared<op::Pad>(constant, pad_value, padding_below, padding_above, padding_interior);
    auto f = make_shared<Function>(pad, op::ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // The Pad must be gone and a lone Constant must feed the result.
    ASSERT_EQ(count_ops_of_type<op::Pad>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto new_const =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(new_const);
    auto values_out = new_const->get_vector<int>();

    vector<int> padded_values{111, 777, 888, 111, 111};
    ASSERT_EQ(padded_values, values_out);
}
TEST(constant_folding, constant_pad_interior)
{
    // Pad {777, 888} with the value 111 using interior padding of 3 (three
    // fill elements between the two input elements, nothing outside).
    // ConstantFolding should replace the Pad node with a single Constant.
    Shape shape_in{2};

    vector<int> values_in{777, 888};
    auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
    auto pad_value = make_shared<op::Constant>(element::i32, Shape{}, vector<int>{111});

    Shape padding_below{0};
    Shape padding_above{0};
    Shape padding_interior{3};

    // NOTE: was misleadingly named `broadcast` (copy-paste from the
    // broadcast test); this is a Pad node.
    auto pad =
        make_shared<op::Pad>(constant, pad_value, padding_below, padding_above, padding_interior);
    auto f = make_shared<Function>(pad, op::ParameterVector{});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);

    // The Pad must be gone and a lone Constant must feed the result.
    ASSERT_EQ(count_ops_of_type<op::Pad>(f), 0);
    ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);

    auto new_const =
        std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
    ASSERT_TRUE(new_const);
    auto values_out = new_const->get_vector<int>();

    vector<int> padded_values{777, 111, 111, 111, 888};
    ASSERT_EQ(padded_values, values_out);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment