Commit 35b04e6a authored by Adam Straw's avatar Adam Straw Committed by Nick Korovaiko

constant broadcast folding (#1139)

* constant broadcast folding

* code review feedback
parent 13f00048
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
* limitations under the License. * limitations under the License.
*******************************************************************************/ *******************************************************************************/
#include <stdint.h>
#include "constant_folding.hpp" #include "constant_folding.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp" #include "ngraph/op/constant.hpp"
#include "ngraph/op/reshape.hpp" #include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp" #include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp" #include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/reshape.hpp" #include "ngraph/runtime/reference/reshape.hpp"
using namespace std; using namespace std;
...@@ -48,6 +52,9 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape() ...@@ -48,6 +52,9 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape()
auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1}); auto reshape = make_shared<op::Reshape>(constant_label, AxisVector{0, 1}, Shape{2, 4, 1});
auto constant_reshape_callback = [constant_label](pattern::Matcher& m) { auto constant_reshape_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map(); auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]); auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
...@@ -63,7 +70,7 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape() ...@@ -63,7 +70,7 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape()
else if (type == element::i8) else if (type == element::i8)
{ {
replace_node(m.get_match_root(), replace_node(m.get_match_root(),
make_constant_reshape<signed char>(constant_match, reshape_match)); make_constant_reshape<int8_t>(constant_match, reshape_match));
return true; return true;
} }
else if (type == element::f32) else if (type == element::f32)
...@@ -85,3 +92,68 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape() ...@@ -85,3 +92,68 @@ void ngraph::pass::ConstantFolding::construct_constant_reshape()
auto reshape_matcher = make_shared<pattern::Matcher>(reshape, constant_reshape_callback); auto reshape_matcher = make_shared<pattern::Matcher>(reshape, constant_reshape_callback);
this->add_matcher(reshape_matcher); this->add_matcher(reshape_matcher);
} }
template <class T>
shared_ptr<op::Constant> make_constant_broadcast(shared_ptr<op::Constant> constant,
shared_ptr<op::Broadcast> broadcast)
{
auto out_shape = broadcast->get_shape();
vector<T> out_vec(shape_size(out_shape));
runtime::reference::broadcast<T>(constant->get_vector<T>().data(),
out_vec.data(),
constant->get_shape(),
out_shape,
broadcast->get_broadcast_axes());
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_broadcast()
{
auto constant_label =
make_shared<pattern::op::Label>(element::f32, Shape{2}, pattern::has_class<op::Constant>());
auto broadcast = make_shared<op::Broadcast>(constant_label, Shape{2, 4}, AxisSet{1});
auto constant_broadcast_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_broadcast_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto broadcast_match = dynamic_pointer_cast<op::Broadcast>(m.get_match_root());
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
make_constant_broadcast<int>(constant_match, broadcast_match));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
make_constant_broadcast<int8_t>(constant_match, broadcast_match));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
make_constant_broadcast<float>(constant_match, broadcast_match));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
make_constant_broadcast<double>(constant_match, broadcast_match));
return true;
}
return false;
};
auto broadcast_matcher = make_shared<pattern::Matcher>(broadcast, constant_broadcast_callback);
this->add_matcher(broadcast_matcher);
}
...@@ -33,8 +33,10 @@ public: ...@@ -33,8 +33,10 @@ public:
: GraphRewrite() : GraphRewrite()
{ {
construct_constant_reshape(); construct_constant_reshape();
construct_constant_broadcast();
} }
private: private:
void construct_constant_reshape(); void construct_constant_reshape();
void construct_constant_broadcast();
}; };
...@@ -73,3 +73,29 @@ TEST(constant_folding, constant_reshape_permute) ...@@ -73,3 +73,29 @@ TEST(constant_folding, constant_reshape_permute)
vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7}; vector<double> values_permute{0, 4, 1, 5, 2, 6, 3, 7};
ASSERT_EQ(values_permute, values_out); ASSERT_EQ(values_permute, values_out);
} }
TEST(constant_folding, constant_broadcast)
{
Shape shape_in{2};
Shape shape_out{2, 4};
vector<int> values_in{0, 1};
auto constant = make_shared<op::Constant>(element::i32, shape_in, values_in);
auto broadcast = make_shared<op::Broadcast>(constant, shape_out, AxisSet{1});
auto f = make_shared<Function>(broadcast, op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<int>();
vector<int> values_permute{0, 0, 0, 0, 1, 1, 1, 1};
ASSERT_EQ(values_permute, values_out);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment