Commit 5ece6de2 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Not (#3326)

* CF for Sign, and extend element type capabilities for unary arithop CF

* CF for Ceiling and Floor

* Update CPU CF builders

* Update CPU CF builders

* CF for Not

* Add tests for new CPU CF folders

* Add tests for recently added CPU CF functors

* Add tests for non-CPU ceiling/floor CF

* Unit tests

* Add test for CPU folder
parent 09952c0b
...@@ -39,6 +39,7 @@ ...@@ -39,6 +39,7 @@
#include "ngraph/op/minimum.hpp" #include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp" #include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp" #include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp" #include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp" #include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp" #include "ngraph/op/pad.hpp"
...@@ -72,6 +73,7 @@ ...@@ -72,6 +73,7 @@
#include "ngraph/runtime/reference/minimum.hpp" #include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp" #include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp" #include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp" #include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/or.hpp" #include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp" #include "ngraph/runtime/reference/pad.hpp"
...@@ -725,8 +727,8 @@ bool is_supported_unary_op(std::shared_ptr<Node> n) ...@@ -725,8 +727,8 @@ bool is_supported_unary_op(std::shared_ptr<Node> n)
{ {
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) || return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) ||
std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) || std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sign>(n) || std::dynamic_pointer_cast<op::Not>(n) || std::dynamic_pointer_cast<op::Relu>(n) ||
std::dynamic_pointer_cast<op::Sqrt>(n); std::dynamic_pointer_cast<op::Sign>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
} }
template <class T> template <class T>
...@@ -778,6 +780,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant, ...@@ -778,6 +780,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::negate<T>( runtime::reference::negate<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape)); constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
} }
else if (std::dynamic_pointer_cast<op::Not>(unary))
{
runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Relu>(unary)) else if (std::dynamic_pointer_cast<op::Relu>(unary))
{ {
runtime::reference::relu<T>( runtime::reference::relu<T>(
...@@ -806,9 +813,11 @@ void pass::ConstantFolding::construct_constant_unary() ...@@ -806,9 +813,11 @@ void pass::ConstantFolding::construct_constant_unary()
{ {
auto constant_label = make_shared<pattern::op::Label>( auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>()); element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_uea = pattern::has_class<op::util::UnaryElementwiseArithmetic>(); auto is_ue = [](std::shared_ptr<Node> n) {
auto uea = return (pattern::has_class<op::util::UnaryElementwiseArithmetic>()(n) ||
std::make_shared<pattern::op::Any>(constant_label, is_uea, NodeVector{constant_label}); pattern::has_class<op::Not>()(n));
};
auto ue = std::make_shared<pattern::op::Any>(constant_label, is_ue, NodeVector{constant_label});
auto constant_unary_callback = [&, constant_label](pattern::Matcher& m) { auto constant_unary_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_unary_callback against node = " NGRAPH_DEBUG << "In callback for constant_unary_callback against node = "
...@@ -890,7 +899,7 @@ void pass::ConstantFolding::construct_constant_unary() ...@@ -890,7 +899,7 @@ void pass::ConstantFolding::construct_constant_unary()
return true; return true;
}; };
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary"); auto reshape_matcher = make_shared<pattern::Matcher>(ue, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE); this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
} }
......
...@@ -550,6 +550,12 @@ namespace ngraph ...@@ -550,6 +550,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign); BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign);
} }
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Not)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::logical_not);
}
#define TI(x) type_index(typeid(x)) #define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher() BuildOpMap& GetGlobalBuildDispatcher()
...@@ -627,6 +633,7 @@ namespace ngraph ...@@ -627,6 +633,7 @@ namespace ngraph
REGISTER_CF_BUILDER(And); REGISTER_CF_BUILDER(And);
REGISTER_CF_BUILDER(Or); REGISTER_CF_BUILDER(Or);
REGISTER_CF_BUILDER(Sign); REGISTER_CF_BUILDER(Sign);
REGISTER_CF_BUILDER(Not);
} }
} }
} }
...@@ -459,6 +459,30 @@ TEST(constant_folding, const_concat) ...@@ -459,6 +459,30 @@ TEST(constant_folding, const_concat)
ASSERT_EQ(values_expected, values_out); ASSERT_EQ(values_expected, values_out);
} }
TEST(constant_folding, const_not)
{
auto constant =
op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
auto logical_not = make_shared<op::Not>(constant);
auto f = make_shared<Function>(logical_not, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 0, 1, 1, 0, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_equal) TEST(constant_folding, const_equal)
{ {
auto constant0 = auto constant0 =
......
...@@ -1218,11 +1218,12 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1218,11 +1218,12 @@ TEST(cpu_test, constant_unary_binary)
auto logical_or = make_shared<op::Or>(i, j); auto logical_or = make_shared<op::Or>(i, j);
auto ceil = make_shared<op::Ceiling>(k); auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k); auto floor = make_shared<op::Floor>(k);
auto logical_not = make_shared<op::Not>(j);
auto func = make_shared<Function>( auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, absn, NodeVector{add, sub, mul, divn, min, max, absn, neg,
neg, sqrt, relu, sign, equal, not_equal, greater, sqrt, relu, sign, equal, not_equal, greater, greater_eq, less,
greater_eq, less, less_eq, logical_and, logical_or, ceil, floor}, less_eq, logical_and, logical_or, ceil, floor, logical_not},
ParameterVector{}); ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{}); auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
...@@ -1253,6 +1254,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1253,6 +1254,7 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(count_ops_of_type<op::Or>(func), 0); ASSERT_EQ(count_ops_of_type<op::Or>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0); ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0); ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Not>(func), 0);
//expected values //expected values
vector<int> add_expected{2, 4, 6, 8}; vector<int> add_expected{2, 4, 6, 8};
...@@ -1275,6 +1277,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1275,6 +1277,7 @@ TEST(cpu_test, constant_unary_binary)
vector<char> or_expected{0, 1, 1, 1}; vector<char> or_expected{0, 1, 1, 1};
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f}; vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f}; vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
vector<char> not_expected{1, 0, 1, 0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected); ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected); ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
...@@ -1299,6 +1302,7 @@ TEST(cpu_test, constant_unary_binary) ...@@ -1299,6 +1302,7 @@ TEST(cpu_test, constant_unary_binary)
get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS)); get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_TRUE(test::all_close_f( ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS)); get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 21), not_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error)); ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment