Commit 97e3559f authored by Nick Korovaiko, committed by Scott Cyphers

Fold some unary and binary ops (#1719)

* unary, binary folding

* fix wrong template args for divide

* add tests

* fix merge breaks
parent ec9854a2
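For orientation: the changes below fold constant subgraphs for the listed unary ops (Abs, Negative) and binary ops (Add, Subtract, Multiply, Divide, Minimum, Maximum) at graph-optimization time. A minimal usage sketch follows; it mirrors the constant_unary_binary test added in this commit. The top-level include paths are assumptions (only the local constant_folding.hpp include appears in the diff), so treat them as illustrative rather than exact.

#include "ngraph/ngraph.hpp"                 // assumed umbrella header
#include "ngraph/pass/constant_folding.hpp"  // assumed install path of constant_folding.hpp
#include "ngraph/pass/manager.hpp"

using namespace ngraph;

int main()
{
    // Two constant inputs; Add is a BinaryElementwiseArithmetic op, so the new
    // construct_constant_binary() matcher applies.
    auto a = std::make_shared<op::Constant>(element::i32, Shape{4}, std::vector<int>{1, 2, 3, 4});
    auto b = std::make_shared<op::Constant>(element::i32, Shape{4}, std::vector<int>{1, 2, 3, 4});
    auto f = std::make_shared<Function>(NodeVector{a + b}, op::ParameterVector{});

    // After the pass runs, the Add node is replaced by a single Constant holding {2, 4, 6, 8}.
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConstantFolding>();
    pass_manager.run_passes(f);
}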
@@ -18,17 +18,33 @@
#include "constant_folding.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/maximum.hpp"
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
using namespace std;
using namespace ngraph;
@@ -234,6 +250,214 @@ void ngraph::pass::ConstantFolding::construct_constant_broadcast()
this->add_matcher(broadcast_matcher);
}
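// Evaluates a supported binary elementwise op on two constant inputs by dispatching
// to the corresponding reference kernel, and returns the result as a new Constant.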
template <class T>
shared_ptr<op::Constant> make_constant_binary(shared_ptr<op::Constant> a,
shared_ptr<op::Constant> b,
shared_ptr<Node> binary)
{
auto out_shape = binary->get_shape();
vector<T> out_vec(shape_size(out_shape));
if (std::dynamic_pointer_cast<op::Add>(binary))
{
runtime::reference::add<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Subtract>(binary))
{
runtime::reference::subtract<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Multiply>(binary))
{
runtime::reference::multiply<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Divide>(binary))
{
runtime::reference::divide<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Minimum>(binary))
{
runtime::reference::minimum<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Maximum>(binary))
{
runtime::reference::maximum<T>(a->get_vector<T>().data(),
b->get_vector<T>().data(),
out_vec.data(),
shape_size(out_shape));
}
else
{
NGRAPH_ASSERT(false)
<< "make_constant_binary must be consistent with is_supported_binary_op";
}
return make_shared<op::Constant>(a->get_element_type(), out_shape, out_vec);
}
bool is_supported_binary_op(std::shared_ptr<Node> n)
{
return (std::dynamic_pointer_cast<op::Add>(n) || std::dynamic_pointer_cast<op::Subtract>(n) ||
std::dynamic_pointer_cast<op::Multiply>(n) ||
std::dynamic_pointer_cast<op::Divide>(n) || std::dynamic_pointer_cast<op::Maximum>(n) ||
std::dynamic_pointer_cast<op::Minimum>(n));
}
void ngraph::pass::ConstantFolding::construct_constant_binary()
{
auto a = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto b = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_bea = pattern::has_class<op::util::BinaryElementwiseArithmetic>();
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto constant_binary_callback = [a, b](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_binary_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto a_match = dynamic_pointer_cast<op::Constant>(pattern_map[a]);
auto b_match = dynamic_pointer_cast<op::Constant>(pattern_map[b]);
auto binary_match = m.get_match_root();
if (!is_supported_binary_op(binary_match))
{
return false;
}
auto type = a_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(),
make_constant_binary<int>(a_match, b_match, binary_match));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
make_constant_binary<int8_t>(a_match, b_match, binary_match));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
make_constant_binary<float>(a_match, b_match, binary_match));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
make_constant_binary<double>(a_match, b_match, binary_match));
return true;
}
return false;
};
auto binary_matcher = make_shared<pattern::Matcher>(bea, constant_binary_callback);
this->add_matcher(binary_matcher);
}
bool is_supported_unary_op(std::shared_ptr<Node> n)
{
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n);
}
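// Evaluates a supported unary op (Abs or Negative) on a constant input using the
// reference kernel, and returns the result as a new Constant.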
template <class T>
shared_ptr<op::Constant> make_constant_unary(shared_ptr<op::Constant> constant,
shared_ptr<Node> unary)
{
auto out_shape = unary->get_shape();
vector<T> out_vec(shape_size(out_shape));
if (std::dynamic_pointer_cast<op::Abs>(unary))
{
runtime::reference::abs<T>(
constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Negative>(unary))
{
runtime::reference::negate<T>(
constant->get_vector<T>().data(), out_vec.data(), shape_size(out_shape));
}
else
{
NGRAPH_ASSERT(false) << "must be consistent with is_supported_unary_op";
}
return make_shared<op::Constant>(constant->get_element_type(), out_shape, out_vec);
}
void ngraph::pass::ConstantFolding::construct_constant_unary()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_uea = pattern::has_class<op::util::UnaryElementwiseArithmetic>();
auto uea =
std::make_shared<pattern::op::Any>(constant_label, is_uea, NodeVector{constant_label});
auto constant_unary_callback = [constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_reshape_callback against node = "
<< m.get_match_root()->get_name();
auto pattern_map = m.get_pattern_map();
auto constant_match = dynamic_pointer_cast<op::Constant>(pattern_map[constant_label]);
auto unary_match = m.get_match_root();
if (!is_supported_unary_op(unary_match))
{
return false;
}
auto type = constant_match->get_element_type();
if (type == element::i32)
{
replace_node(m.get_match_root(), make_constant_unary<int>(constant_match, unary_match));
return true;
}
else if (type == element::i8)
{
replace_node(m.get_match_root(),
make_constant_unary<int8_t>(constant_match, unary_match));
return true;
}
else if (type == element::f32)
{
replace_node(m.get_match_root(),
make_constant_unary<float>(constant_match, unary_match));
return true;
}
else if (type == element::f64)
{
replace_node(m.get_match_root(),
make_constant_unary<double>(constant_match, unary_match));
return true;
}
return false;
};
auto unary_matcher = make_shared<pattern::Matcher>(uea, constant_unary_callback);
this->add_matcher(unary_matcher);
}
template <class QUANT, class REAL>
shared_ptr<op::Constant> make_constant_dequantize(shared_ptr<op::Constant> constant,
shared_ptr<op::Dequantize> dequant,
......
@@ -33,7 +33,9 @@ class ngraph::pass::ConstantFolding : public ngraph::pass::GraphRewrite
RESHAPE,
BROADCAST,
PAD,
DEQUANTIZE
DEQUANTIZE,
UNARY,
BINARY
};
public:
@@ -43,6 +45,8 @@ public:
construct_constant_reshape();
construct_constant_broadcast();
construct_constant_pad();
construct_constant_unary();
construct_constant_binary();
construct_constant_dequantize();
}
@@ -58,6 +62,8 @@ public:
case CFTransformations::RESHAPE: construct_constant_reshape(); break;
case CFTransformations::BROADCAST: construct_constant_broadcast(); break;
case CFTransformations::PAD: construct_constant_pad(); break;
case CFTransformations::UNARY: construct_constant_unary(); break;
case CFTransformations::BINARY: construct_constant_binary(); break;
case CFTransformations::DEQUANTIZE: construct_constant_dequantize(); break;
}
}
@@ -67,5 +73,7 @@ private:
void construct_constant_reshape();
void construct_constant_broadcast();
void construct_constant_pad();
void construct_constant_unary();
void construct_constant_binary();
void construct_constant_dequantize();
};
@@ -164,6 +164,59 @@ TEST(constant_folding, constant_pad_interior)
ASSERT_EQ(padded_values, values_out);
}
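// Returns the values of the constant that feeds result `pos` after folding has run.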
template <typename T>
static std::vector<T> get_result_constant(std::shared_ptr<Function> f, size_t pos)
{
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(pos)->get_argument(0));
return new_const->get_vector<T>();
}
TEST(constant_folding, constant_unary_binary)
{
Shape shape_in{4};
vector<int> values_a{1, 2, 3, 4};
vector<int> values_b{1, 2, 3, 4};
vector<int> values_c{-1, -1, -1, -1};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
auto add = a + b;
auto sub = a - b;
auto mul = a * b;
auto divn = a / b;
auto min = make_shared<op::Minimum>(c, a);
auto max = make_shared<op::Maximum>(a, c);
auto absn = make_shared<op::Abs>(c);
auto neg = make_shared<op::Negative>(c);
auto f = make_shared<Function>(NodeVector{add, sub, mul, divn, min, max, absn, neg},
op::ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
// expected values
vector<int> add_expected{2, 4, 6, 8};
vector<int> sub_expected{0, 0, 0, 0};
vector<int> mul_expected{1, 4, 9, 16};
vector<int> div_expected{1, 1, 1, 1};
vector<int> min_expected{-1, -1, -1, -1};
vector<int> max_expected{1, 2, 3, 4};
vector<int> abs_neg_expected{1, 1, 1, 1};
ASSERT_EQ(get_result_constant<int>(f, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(f, 1), sub_expected);
ASSERT_EQ(get_result_constant<int>(f, 2), mul_expected);
ASSERT_EQ(get_result_constant<int>(f, 3), div_expected);
ASSERT_EQ(get_result_constant<int>(f, 4), min_expected);
ASSERT_EQ(get_result_constant<int>(f, 5), max_expected);
ASSERT_EQ(get_result_constant<int>(f, 6), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(f, 7), abs_neg_expected);
}
TEST(constant_folding, const_dequantize)
{
Shape input_shape{12};
......