Commit 24b73881 authored by Fabian Boemer's avatar Fabian Boemer Committed by Scott Cyphers

Added power constant folding (#3725)

* Added power constant folding

* Style apply

* Removed python operator pow
parent 838a6610
......@@ -28,6 +28,7 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/runtime/reference/add.hpp"
......@@ -43,6 +44,7 @@
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/power.hpp"
#include "ngraph/runtime/reference/subtract.hpp"
#include "ngraph/runtime/reference/xor.hpp"
......@@ -304,6 +306,20 @@ shared_ptr<op::Constant> fold_constant_binary_arithmetic(shared_ptr<op::Constant
multiply_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto power_node = as_type_ptr<op::Power>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
"Input/output types do not match");
vector<Tout> out_vec(shape_size(out_shape));
shared_ptr<op::Power> powop = as_type_ptr<op::Power>(binary);
runtime::reference::power<Tin>(a->get_data_ptr<Tin>(),
b->get_data_ptr<Tin>(),
out_vec.data(),
a->get_shape(),
b->get_shape(),
power_node->get_autob());
return make_shared<op::Constant>(binary->get_element_type(), out_shape, out_vec);
}
else if (auto subtract_node = as_type_ptr<op::Subtract>(binary))
{
NGRAPH_CHECK(element::from<Tin>() == element::from<Tout>(),
......@@ -352,7 +368,8 @@ bool is_supported_binary_op(std::shared_ptr<Node> n)
is_type<op::Equal>(n) || is_type<op::Greater>(n) || is_type<op::GreaterEq>(n) ||
is_type<op::Less>(n) || is_type<op::LessEq>(n) || is_type<op::Maximum>(n) ||
is_type<op::Minimum>(n) || is_type<op::Multiply>(n) || is_type<op::NotEqual>(n) ||
is_type<op::Or>(n) || is_type<op::Subtract>(n) || is_type<op::Xor>(n));
is_type<op::Or>(n) || is_type<op::Power>(n) || is_type<op::Subtract>(n) ||
is_type<op::Xor>(n));
}
void pass::ConstantFolding::construct_constant_binary()
......
......@@ -185,6 +185,7 @@ TEST(constant_folding, constant_unary_binary)
auto sub = a - b;
auto mul = a * b;
auto divn = a / b;
auto pow = make_shared<op::Power>(a, b);
auto min = make_shared<op::Minimum>(c, a);
auto max = make_shared<op::Maximum>(a, c);
auto absn = make_shared<op::Abs>(c);
......@@ -194,6 +195,7 @@ TEST(constant_folding, constant_unary_binary)
auto sub_autob_numpy = make_shared<op::Subtract>(a, e, op::AutoBroadcastType::NUMPY);
auto mul_autob_numpy = make_shared<op::Multiply>(a, e, op::AutoBroadcastType::NUMPY);
auto div_autob_numpy = make_shared<op::Divide>(a, g, op::AutoBroadcastType::NUMPY);
auto pow_autob_numpy = make_shared<op::Power>(a, g, op::AutoBroadcastType::NUMPY);
auto min_autob_numpy = make_shared<op::Minimum>(a, f, op::AutoBroadcastType::NUMPY);
auto max_autob_numpy = make_shared<op::Maximum>(a, f, op::AutoBroadcastType::NUMPY);
auto equal_autob_numpy = make_shared<op::Equal>(a, g, op::AutoBroadcastType::NUMPY);
......@@ -212,6 +214,7 @@ TEST(constant_folding, constant_unary_binary)
sub,
mul,
divn,
pow,
min,
max,
absn,
......@@ -221,6 +224,7 @@ TEST(constant_folding, constant_unary_binary)
sub_autob_numpy,
mul_autob_numpy,
div_autob_numpy,
pow_autob_numpy,
min_autob_numpy,
max_autob_numpy,
equal_autob_numpy,
......@@ -244,6 +248,7 @@ TEST(constant_folding, constant_unary_binary)
vector<int> sub_expected{0, 0, 0, 0};
vector<int> mul_expected{1, 4, 9, 16};
vector<int> div_expected{1, 1, 1, 1};
vector<int> pow_expected{1, 4, 27, 256};
vector<int> min_expected{-1, -1, -1, -1};
vector<int> max_expected{1, 2, 3, 4};
vector<int> abs_neg_expected{1, 1, 1, 1};
......@@ -252,6 +257,7 @@ TEST(constant_folding, constant_unary_binary)
vector<int> sub_autob_numpy_expected{-4, -4, -2, -2};
vector<int> mul_autob_numpy_expected{5, 12, 15, 24};
vector<int> div_autob_numpy_expected{1, 0, 3, 1};
vector<int> pow_autob_numpy_expected{1, 16, 3, 256};
vector<int> min_autob_numpy_expected{0, 2, 0, 4};
vector<int> max_autob_numpy_expected{1, 10, 3, 10};
vector<char> equal_autob_numpy_expected{1, 0, 0, 1};
......@@ -268,26 +274,28 @@ TEST(constant_folding, constant_unary_binary)
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
ASSERT_EQ(get_result_constant<int>(func, 2), mul_expected);
ASSERT_EQ(get_result_constant<int>(func, 3), div_expected);
ASSERT_EQ(get_result_constant<int>(func, 4), min_expected);
ASSERT_EQ(get_result_constant<int>(func, 5), max_expected);
ASSERT_EQ(get_result_constant<int>(func, 6), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(func, 4), pow_expected);
ASSERT_EQ(get_result_constant<int>(func, 5), min_expected);
ASSERT_EQ(get_result_constant<int>(func, 6), max_expected);
ASSERT_EQ(get_result_constant<int>(func, 7), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(func, 8), sqrt_expected);
ASSERT_EQ(get_result_constant<int>(func, 9), add_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 10), sub_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 11), mul_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 12), div_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 13), min_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 14), max_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 15), equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 16), not_equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 17), greater_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 18), greater_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 19), less_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 20), less_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 21), logical_and_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 22), logical_or_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 23), logical_xor_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 8), abs_neg_expected);
ASSERT_EQ(get_result_constant<int>(func, 9), sqrt_expected);
ASSERT_EQ(get_result_constant<int>(func, 10), add_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 11), sub_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 12), mul_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 13), div_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 14), pow_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 15), min_autob_numpy_expected);
ASSERT_EQ(get_result_constant<int>(func, 16), max_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 17), equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 18), not_equal_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 19), greater_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 20), greater_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 21), less_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 22), less_eq_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 23), logical_and_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 24), logical_or_autob_numpy_expected);
ASSERT_EQ(get_result_constant<char>(func, 25), logical_xor_autob_numpy_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment