Commit 09952c0b authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Ceiling and Floor (#3320)

* CF for Sign, and extend element type capabilities for unary arithop CF

* CF for Ceiling and Floor

* Update CPU CF builders

* Update CPU CF builders

* Add tests for new CPU CF folders

* Add tests for recently added CPU CF functors

* Add tests for non-CPU ceiling/floor CF

* Unit tests
parent 831df41d
......@@ -22,6 +22,7 @@
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
......@@ -29,6 +30,7 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/experimental/shape_of.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/greater.hpp"
#include "ngraph/op/greater_eq.hpp"
#include "ngraph/op/less.hpp"
......@@ -55,11 +57,13 @@
#include "ngraph/runtime/reference/add.hpp"
#include "ngraph/runtime/reference/and.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "ngraph/runtime/reference/equal.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/greater.hpp"
#include "ngraph/runtime/reference/greater_eq.hpp"
#include "ngraph/runtime/reference/less.hpp"
......@@ -719,7 +723,8 @@ void pass::ConstantFolding::construct_constant_binary()
bool is_supported_unary_op(std::shared_ptr<Node> n)
{
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) ||
std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sign>(n) ||
std::dynamic_pointer_cast<op::Sqrt>(n);
}
......@@ -758,6 +763,16 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::abs<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Ceiling>(unary))
{
runtime::reference::ceiling<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Floor>(unary))
{
runtime::reference::floor<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Negative>(unary))
{
runtime::reference::negate<T>(
......
......@@ -470,6 +470,18 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sqrt);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Floor)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::floor);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Ceiling)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::ceil);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Equal)
{
......@@ -604,6 +616,8 @@ namespace ngraph
REGISTER_CF_BUILDER(Negative);
REGISTER_CF_BUILDER(Relu);
REGISTER_CF_BUILDER(Sqrt);
REGISTER_CF_BUILDER(Floor);
REGISTER_CF_BUILDER(Ceiling);
REGISTER_CF_BUILDER(Equal);
REGISTER_CF_BUILDER(NotEqual);
REGISTER_CF_BUILDER(Greater);
......
......@@ -667,6 +667,54 @@ TEST(constant_folding, const_or)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_ceiling)
{
auto constant = op::Constant::create(
element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
auto ceil = make_shared<op::Ceiling>(constant);
auto f = make_shared<Function>(ceil, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{0.0f, 1.0f, 0.0f, -2.0f, 3.0f, 3.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, const_floor)
{
auto constant = op::Constant::create(
element::f32, Shape{2, 3}, vector<float>{0.0f, 0.1f, -0.1f, -2.5f, 2.5f, 3.0f});
auto floor = make_shared<op::Floor>(constant);
auto f = make_shared<Function>(floor, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Floor>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<float>();
vector<float> values_expected{0.0f, 0.0f, -1.0f, -3.0f, 2.0f, 3.0f};
ASSERT_TRUE(test::all_close_f(values_out, values_expected, MIN_FLOAT_TOLERANCE_BITS));
}
TEST(constant_folding, pass_property)
{
auto pass = std::make_shared<ngraph::pass::ConstantFolding>();
......
......@@ -1183,6 +1183,7 @@ TEST(cpu_test, constant_unary_binary)
vector<int> values_h{2, 2, 3, 3};
vector<char> values_i{0, 0, 1, 1};
vector<char> values_j{0, 1, 0, 1};
vector<float> values_k{-0.1f, 0.0f, -1.5f, 2.6f};
auto a = make_shared<op::Constant>(element::i32, shape_in, values_a);
auto b = make_shared<op::Constant>(element::i32, shape_in, values_b);
auto c = make_shared<op::Constant>(element::i32, shape_in, values_c);
......@@ -1193,6 +1194,7 @@ TEST(cpu_test, constant_unary_binary)
auto h = make_shared<op::Constant>(element::i32, shape_in, values_h);
auto i = make_shared<op::Constant>(element::boolean, shape_in, values_i);
auto j = make_shared<op::Constant>(element::boolean, shape_in, values_j);
auto k = make_shared<op::Constant>(element::f32, shape_in, values_k);
auto add = a + b;
auto sub = a - b;
......@@ -1214,27 +1216,14 @@ TEST(cpu_test, constant_unary_binary)
auto less_eq = make_shared<op::LessEq>(g, h);
auto logical_and = make_shared<op::And>(i, j);
auto logical_or = make_shared<op::Or>(i, j);
auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k);
auto func = make_shared<Function>(NodeVector{add,
sub,
mul,
divn,
min,
max,
absn,
neg,
sqrt,
relu,
sign,
equal,
not_equal,
greater,
greater_eq,
less,
less_eq,
logical_and,
logical_or},
ParameterVector{});
auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, absn,
neg, sqrt, relu, sign, equal, not_equal, greater,
greater_eq, less, less_eq, logical_and, logical_or, ceil, floor},
ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
......@@ -1262,6 +1251,8 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(count_ops_of_type<op::LessEq>(func), 0);
ASSERT_EQ(count_ops_of_type<op::And>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Or>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0);
//expected values
vector<int> add_expected{2, 4, 6, 8};
......@@ -1282,6 +1273,8 @@ TEST(cpu_test, constant_unary_binary)
vector<char> less_eq_expected{1, 1, 1, 0};
vector<char> and_expected{0, 0, 0, 1};
vector<char> or_expected{0, 1, 1, 1};
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
......@@ -1302,6 +1295,10 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(get_result_constant<char>(func, 16), less_eq_expected);
ASSERT_EQ(get_result_constant<char>(func, 17), and_expected);
ASSERT_EQ(get_result_constant<char>(func, 18), or_expected);
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment