Commit 5ece6de2 authored by Adam Procter's avatar Adam Procter Committed by Scott Cyphers

ConstantFolding for Not (#3326)

* CF for Sign, and extend element type capabilities for unary arithop CF

* CF for Ceiling and Floor

* Update CPU CF builders

* Update CPU CF builders

* CF for Not

* Add tests for new CPU CF folders

* Add tests for recently added CPU CF functors

* Add tests for non-CPU ceiling/floor CF

* Unit tests

* Add test for CPU folder
parent 09952c0b
......@@ -39,6 +39,7 @@
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/not.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/pad.hpp"
......@@ -72,6 +73,7 @@
#include "ngraph/runtime/reference/minimum.hpp"
#include "ngraph/runtime/reference/multiply.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/not_equal.hpp"
#include "ngraph/runtime/reference/or.hpp"
#include "ngraph/runtime/reference/pad.hpp"
......@@ -725,8 +727,8 @@ bool is_supported_unary_op(std::shared_ptr<Node> n)
{
return std::dynamic_pointer_cast<op::Abs>(n) || std::dynamic_pointer_cast<op::Ceiling>(n) ||
std::dynamic_pointer_cast<op::Floor>(n) || std::dynamic_pointer_cast<op::Negative>(n) ||
std::dynamic_pointer_cast<op::Relu>(n) || std::dynamic_pointer_cast<op::Sign>(n) ||
std::dynamic_pointer_cast<op::Sqrt>(n);
std::dynamic_pointer_cast<op::Not>(n) || std::dynamic_pointer_cast<op::Relu>(n) ||
std::dynamic_pointer_cast<op::Sign>(n) || std::dynamic_pointer_cast<op::Sqrt>(n);
}
template <class T>
......@@ -778,6 +780,11 @@ shared_ptr<op::Constant> fold_constant_unary(shared_ptr<op::Constant> constant,
runtime::reference::negate<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Not>(unary))
{
runtime::reference::logical_not<T>(
constant->get_data_ptr<T>(), out_vec.data(), shape_size(out_shape));
}
else if (std::dynamic_pointer_cast<op::Relu>(unary))
{
runtime::reference::relu<T>(
......@@ -806,9 +813,11 @@ void pass::ConstantFolding::construct_constant_unary()
{
auto constant_label = make_shared<pattern::op::Label>(
element::f32, Shape{2, 4}, pattern::has_class<op::Constant>());
auto is_uea = pattern::has_class<op::util::UnaryElementwiseArithmetic>();
auto uea =
std::make_shared<pattern::op::Any>(constant_label, is_uea, NodeVector{constant_label});
auto is_ue = [](std::shared_ptr<Node> n) {
return (pattern::has_class<op::util::UnaryElementwiseArithmetic>()(n) ||
pattern::has_class<op::Not>()(n));
};
auto ue = std::make_shared<pattern::op::Any>(constant_label, is_ue, NodeVector{constant_label});
auto constant_unary_callback = [&, constant_label](pattern::Matcher& m) {
NGRAPH_DEBUG << "In callback for constant_unary_callback against node = "
......@@ -890,7 +899,7 @@ void pass::ConstantFolding::construct_constant_unary()
return true;
};
auto reshape_matcher = make_shared<pattern::Matcher>(uea, "ConstantFolding.ConstantUnary");
auto reshape_matcher = make_shared<pattern::Matcher>(ue, "ConstantFolding.ConstantUnary");
this->add_matcher(reshape_matcher, constant_unary_callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
......
......@@ -550,6 +550,12 @@ namespace ngraph
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::sign);
}
template <>
NodeExecutorTy Builder::BUILDER_CF_DECL(ngraph::op::Not)
{
BUILD_UNARY_ELEMWISE_CF_FUNCTOR(runtime::cpu::kernel::logical_not);
}
#define TI(x) type_index(typeid(x))
BuildOpMap& GetGlobalBuildDispatcher()
......@@ -627,6 +633,7 @@ namespace ngraph
REGISTER_CF_BUILDER(And);
REGISTER_CF_BUILDER(Or);
REGISTER_CF_BUILDER(Sign);
REGISTER_CF_BUILDER(Not);
}
}
}
......@@ -459,6 +459,30 @@ TEST(constant_folding, const_concat)
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_not)
{
auto constant =
op::Constant::create(element::boolean, Shape{2, 3}, vector<char>{0, 1, 0, 0, 1, 1});
auto logical_not = make_shared<op::Not>(constant);
auto f = make_shared<Function>(logical_not, ParameterVector{});
pass::Manager pass_manager;
pass_manager.register_pass<pass::ConstantFolding>();
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Not>(f), 0);
ASSERT_EQ(count_ops_of_type<op::Constant>(f), 1);
auto new_const =
std::dynamic_pointer_cast<op::Constant>(f->get_results().at(0)->get_argument(0));
ASSERT_TRUE(new_const);
auto values_out = new_const->get_vector<char>();
vector<char> values_expected{1, 0, 1, 1, 0, 0};
ASSERT_EQ(values_expected, values_out);
}
TEST(constant_folding, const_equal)
{
auto constant0 =
......
......@@ -1218,11 +1218,12 @@ TEST(cpu_test, constant_unary_binary)
auto logical_or = make_shared<op::Or>(i, j);
auto ceil = make_shared<op::Ceiling>(k);
auto floor = make_shared<op::Floor>(k);
auto logical_not = make_shared<op::Not>(j);
auto func = make_shared<Function>(
NodeVector{add, sub, mul, divn, min, max, absn,
neg, sqrt, relu, sign, equal, not_equal, greater,
greater_eq, less, less_eq, logical_and, logical_or, ceil, floor},
NodeVector{add, sub, mul, divn, min, max, absn, neg,
sqrt, relu, sign, equal, not_equal, greater, greater_eq, less,
less_eq, logical_and, logical_or, ceil, floor, logical_not},
ParameterVector{});
auto func_error = make_shared<Function>(NodeVector{neg_sqrt}, ParameterVector{});
......@@ -1253,6 +1254,7 @@ TEST(cpu_test, constant_unary_binary)
ASSERT_EQ(count_ops_of_type<op::Or>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Ceiling>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Floor>(func), 0);
ASSERT_EQ(count_ops_of_type<op::Not>(func), 0);
//expected values
vector<int> add_expected{2, 4, 6, 8};
......@@ -1275,6 +1277,7 @@ TEST(cpu_test, constant_unary_binary)
vector<char> or_expected{0, 1, 1, 1};
vector<float> ceil_expected{0.0f, 0.0f, -1.0f, 3.0f};
vector<float> floor_expected{-1.0f, 0.0f, -2.0f, 2.0f};
vector<char> not_expected{1, 0, 1, 0};
ASSERT_EQ(get_result_constant<int>(func, 0), add_expected);
ASSERT_EQ(get_result_constant<int>(func, 1), sub_expected);
......@@ -1299,6 +1302,7 @@ TEST(cpu_test, constant_unary_binary)
get_result_constant<float>(func, 19), ceil_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_TRUE(test::all_close_f(
get_result_constant<float>(func, 20), floor_expected, MIN_FLOAT_TOLERANCE_BITS));
ASSERT_EQ(get_result_constant<char>(func, 21), not_expected);
ASSERT_ANY_THROW(pass_manager.run_passes(func_error));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment